3v324v23 commited on
Commit
a09a133
·
1 Parent(s): c138d1f

working version 1

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +122 -0
  2. configs/_base_/datasets/parking_instance.py +48 -0
  3. configs/_base_/datasets/parking_instance_coco.py +49 -0
  4. configs/_base_/datasets/people_real_coco.py +49 -0
  5. configs/_base_/datasets/walt_people.py +49 -0
  6. configs/_base_/datasets/walt_vehicle.py +49 -0
  7. configs/_base_/default_runtime.py +16 -0
  8. configs/_base_/models/mask_rcnn_swin_fpn.py +127 -0
  9. configs/_base_/models/occ_mask_rcnn_swin_fpn.py +127 -0
  10. configs/_base_/schedules/schedule_1x.py +11 -0
  11. configs/walt/walt_people.py +80 -0
  12. configs/walt/walt_vehicle.py +80 -0
  13. docker/Dockerfile +52 -0
  14. github_vis/cwalt.gif +0 -0
  15. github_vis/vis_cars.gif +0 -0
  16. github_vis/vis_people.gif +0 -0
  17. mmcv_custom/__init__.py +5 -0
  18. mmcv_custom/checkpoint.py +500 -0
  19. mmcv_custom/runner/__init__.py +8 -0
  20. mmcv_custom/runner/checkpoint.py +85 -0
  21. mmcv_custom/runner/epoch_based_runner.py +104 -0
  22. mmdet/__init__.py +28 -0
  23. mmdet/apis/__init__.py +10 -0
  24. mmdet/apis/inference.py +217 -0
  25. mmdet/apis/test.py +189 -0
  26. mmdet/apis/train.py +185 -0
  27. mmdet/core/__init__.py +7 -0
  28. mmdet/core/anchor/__init__.py +11 -0
  29. mmdet/core/anchor/anchor_generator.py +727 -0
  30. mmdet/core/anchor/builder.py +7 -0
  31. mmdet/core/anchor/point_generator.py +37 -0
  32. mmdet/core/anchor/utils.py +71 -0
  33. mmdet/core/bbox/__init__.py +27 -0
  34. mmdet/core/bbox/assigners/__init__.py +16 -0
  35. mmdet/core/bbox/assigners/approx_max_iou_assigner.py +145 -0
  36. mmdet/core/bbox/assigners/assign_result.py +204 -0
  37. mmdet/core/bbox/assigners/atss_assigner.py +178 -0
  38. mmdet/core/bbox/assigners/base_assigner.py +9 -0
  39. mmdet/core/bbox/assigners/center_region_assigner.py +335 -0
  40. mmdet/core/bbox/assigners/grid_assigner.py +155 -0
  41. mmdet/core/bbox/assigners/hungarian_assigner.py +145 -0
  42. mmdet/core/bbox/assigners/max_iou_assigner.py +212 -0
  43. mmdet/core/bbox/assigners/point_assigner.py +133 -0
  44. mmdet/core/bbox/assigners/region_assigner.py +221 -0
  45. mmdet/core/bbox/builder.py +20 -0
  46. mmdet/core/bbox/coder/__init__.py +13 -0
  47. mmdet/core/bbox/coder/base_bbox_coder.py +17 -0
  48. mmdet/core/bbox/coder/bucketing_bbox_coder.py +350 -0
  49. mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +237 -0
  50. mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py +215 -0
.gitignore ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+
106
+ data/
107
+ data
108
+ .vscode
109
+ .idea
110
+ .DS_Store
111
+
112
+ # custom
113
+ *.pkl
114
+ *.pkl.json
115
+ *.log.json
116
+ work_dirs/
117
+
118
+ # Pytorch
119
+ *.pth
120
+ *.py~
121
+ *.sh~
122
+
configs/_base_/datasets/parking_instance.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_type = 'ParkingDataset'
2
+ data_root = 'data/parking/'
3
+ img_norm_cfg = dict(
4
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
5
+ train_pipeline = [
6
+ dict(type='LoadImageFromFile'),
7
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
8
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
9
+ dict(type='RandomFlip', flip_ratio=0.5),
10
+ dict(type='Normalize', **img_norm_cfg),
11
+ dict(type='Pad', size_divisor=32),
12
+ dict(type='DefaultFormatBundle'),
13
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d','gt_bboxes_3d_proj']),
14
+ ]
15
+ test_pipeline = [
16
+ dict(type='LoadImageFromFile'),
17
+ dict(
18
+ type='MultiScaleFlipAug',
19
+ img_scale=(1333, 800),
20
+ flip=False,
21
+ transforms=[
22
+ dict(type='Resize', keep_ratio=True),
23
+ dict(type='RandomFlip'),
24
+ dict(type='Normalize', **img_norm_cfg),
25
+ dict(type='Pad', size_divisor=32),
26
+ dict(type='ImageToTensor', keys=['img']),
27
+ dict(type='Collect', keys=['img']),
28
+ ])
29
+ ]
30
+ data = dict(
31
+ samples_per_gpu=1,
32
+ workers_per_gpu=1,
33
+ train=dict(
34
+ type=dataset_type,
35
+ ann_file=data_root + 'GT_data/',
36
+ img_prefix=data_root + 'images/',
37
+ pipeline=train_pipeline),
38
+ val=dict(
39
+ type=dataset_type,
40
+ ann_file=data_root + 'GT_data/',
41
+ img_prefix=data_root + 'images/',
42
+ pipeline=test_pipeline),
43
+ test=dict(
44
+ type=dataset_type,
45
+ ann_file=data_root + 'GT_data/',
46
+ img_prefix=data_root + 'images/',
47
+ pipeline=test_pipeline))
48
+ evaluation = dict(metric=['bbox'])#, 'segm'])
configs/_base_/datasets/parking_instance_coco.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_type = 'ParkingCocoDataset'
2
+ data_root = 'data/parking/'
3
+ data_root_test = 'data/parking_highres/'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ train_pipeline = [
7
+ dict(type='LoadImageFromFile'),
8
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
9
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
10
+ dict(type='RandomFlip', flip_ratio=0.5),
11
+ dict(type='Normalize', **img_norm_cfg),
12
+ dict(type='Pad', size_divisor=32),
13
+ dict(type='DefaultFormatBundle'),
14
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
15
+ ]
16
+ test_pipeline = [
17
+ dict(type='LoadImageFromFile'),
18
+ dict(
19
+ type='MultiScaleFlipAug',
20
+ img_scale=(1333, 800),
21
+ flip=False,
22
+ transforms=[
23
+ dict(type='Resize', keep_ratio=True),
24
+ dict(type='RandomFlip'),
25
+ dict(type='Normalize', **img_norm_cfg),
26
+ dict(type='Pad', size_divisor=32),
27
+ dict(type='ImageToTensor', keys=['img']),
28
+ dict(type='Collect', keys=['img']),
29
+ ])
30
+ ]
31
+ data = dict(
32
+ samples_per_gpu=6,
33
+ workers_per_gpu=6,
34
+ train=dict(
35
+ type=dataset_type,
36
+ ann_file=data_root + 'GT_data/',
37
+ img_prefix=data_root + 'images/',
38
+ pipeline=train_pipeline),
39
+ val=dict(
40
+ type=dataset_type,
41
+ ann_file=data_root_test + 'GT_data/',
42
+ img_prefix=data_root_test + 'images',
43
+ pipeline=test_pipeline),
44
+ test=dict(
45
+ type=dataset_type,
46
+ ann_file=data_root_test + 'GT_data/',
47
+ img_prefix=data_root_test + 'images',
48
+ pipeline=test_pipeline))
49
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/people_real_coco.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_type = 'WaltDataset'
2
+ data_root = 'data/cwalt_train/'
3
+ data_root_test = 'data/cwalt_test/'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ train_pipeline = [
7
+ dict(type='LoadImageFromFile'),
8
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
9
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
10
+ dict(type='RandomFlip', flip_ratio=0.5),
11
+ dict(type='Normalize', **img_norm_cfg),
12
+ dict(type='Pad', size_divisor=32),
13
+ dict(type='DefaultFormatBundle'),
14
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
15
+ ]
16
+ test_pipeline = [
17
+ dict(type='LoadImageFromFile'),
18
+ dict(
19
+ type='MultiScaleFlipAug',
20
+ img_scale=(1333, 800),
21
+ flip=False,
22
+ transforms=[
23
+ dict(type='Resize', keep_ratio=True),
24
+ dict(type='RandomFlip'),
25
+ dict(type='Normalize', **img_norm_cfg),
26
+ dict(type='Pad', size_divisor=32),
27
+ dict(type='ImageToTensor', keys=['img']),
28
+ dict(type='Collect', keys=['img']),
29
+ ])
30
+ ]
31
+ data = dict(
32
+ samples_per_gpu=8,
33
+ workers_per_gpu=8,
34
+ train=dict(
35
+ type=dataset_type,
36
+ ann_file=data_root + '/',
37
+ img_prefix=data_root + '/',
38
+ pipeline=train_pipeline),
39
+ val=dict(
40
+ type=dataset_type,
41
+ ann_file=data_root_test + '/',
42
+ img_prefix=data_root_test + '/',
43
+ pipeline=test_pipeline),
44
+ test=dict(
45
+ type=dataset_type,
46
+ ann_file=data_root_test + '/',
47
+ img_prefix=data_root_test + '/',
48
+ pipeline=test_pipeline))
49
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/walt_people.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_type = 'WaltDataset'
2
+ data_root = 'data/cwalt_train/'
3
+ data_root_test = 'data/cwalt_test/'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ train_pipeline = [
7
+ dict(type='LoadImageFromFile'),
8
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
9
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
10
+ dict(type='RandomFlip', flip_ratio=0.5),
11
+ dict(type='Normalize', **img_norm_cfg),
12
+ dict(type='Pad', size_divisor=32),
13
+ dict(type='DefaultFormatBundle'),
14
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
15
+ ]
16
+ test_pipeline = [
17
+ dict(type='LoadImageFromFile'),
18
+ dict(
19
+ type='MultiScaleFlipAug',
20
+ img_scale=(1333, 800),
21
+ flip=False,
22
+ transforms=[
23
+ dict(type='Resize', keep_ratio=True),
24
+ dict(type='RandomFlip'),
25
+ dict(type='Normalize', **img_norm_cfg),
26
+ dict(type='Pad', size_divisor=32),
27
+ dict(type='ImageToTensor', keys=['img']),
28
+ dict(type='Collect', keys=['img']),
29
+ ])
30
+ ]
31
+ data = dict(
32
+ samples_per_gpu=8,
33
+ workers_per_gpu=8,
34
+ train=dict(
35
+ type=dataset_type,
36
+ ann_file=data_root + '/',
37
+ img_prefix=data_root + '/',
38
+ pipeline=train_pipeline),
39
+ val=dict(
40
+ type=dataset_type,
41
+ ann_file=data_root_test + '/',
42
+ img_prefix=data_root_test + '/',
43
+ pipeline=test_pipeline),
44
+ test=dict(
45
+ type=dataset_type,
46
+ ann_file=data_root_test + '/',
47
+ img_prefix=data_root_test + '/',
48
+ pipeline=test_pipeline))
49
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/walt_vehicle.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_type = 'WaltDataset'
2
+ data_root = 'data/cwalt_train/'
3
+ data_root_test = 'data/cwalt_test/'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ train_pipeline = [
7
+ dict(type='LoadImageFromFile'),
8
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
9
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
10
+ dict(type='RandomFlip', flip_ratio=0.5),
11
+ dict(type='Normalize', **img_norm_cfg),
12
+ dict(type='Pad', size_divisor=32),
13
+ dict(type='DefaultFormatBundle'),
14
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
15
+ ]
16
+ test_pipeline = [
17
+ dict(type='LoadImageFromFile'),
18
+ dict(
19
+ type='MultiScaleFlipAug',
20
+ img_scale=(1333, 800),
21
+ flip=False,
22
+ transforms=[
23
+ dict(type='Resize', keep_ratio=True),
24
+ dict(type='RandomFlip'),
25
+ dict(type='Normalize', **img_norm_cfg),
26
+ dict(type='Pad', size_divisor=32),
27
+ dict(type='ImageToTensor', keys=['img']),
28
+ dict(type='Collect', keys=['img']),
29
+ ])
30
+ ]
31
+ data = dict(
32
+ samples_per_gpu=5,
33
+ workers_per_gpu=5,
34
+ train=dict(
35
+ type=dataset_type,
36
+ ann_file=data_root + '/',
37
+ img_prefix=data_root + '/',
38
+ pipeline=train_pipeline),
39
+ val=dict(
40
+ type=dataset_type,
41
+ ann_file=data_root_test + '/',
42
+ img_prefix=data_root_test + '/',
43
+ pipeline=test_pipeline),
44
+ test=dict(
45
+ type=dataset_type,
46
+ ann_file=data_root_test + '/',
47
+ img_prefix=data_root_test + '/',
48
+ pipeline=test_pipeline))
49
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/default_runtime.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoint_config = dict(interval=1)
2
+ # yapf:disable
3
+ log_config = dict(
4
+ interval=50,
5
+ hooks=[
6
+ dict(type='TextLoggerHook'),
7
+ # dict(type='TensorboardLoggerHook')
8
+ ])
9
+ # yapf:enable
10
+ custom_hooks = [dict(type='NumClassCheckHook')]
11
+
12
+ dist_params = dict(backend='nccl')
13
+ log_level = 'INFO'
14
+ load_from = None
15
+ resume_from = None
16
+ workflow = [('train', 1)]
configs/_base_/models/mask_rcnn_swin_fpn.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ model = dict(
3
+ type='MaskRCNN',
4
+ pretrained=None,
5
+ backbone=dict(
6
+ type='SwinTransformer',
7
+ embed_dim=96,
8
+ depths=[2, 2, 6, 2],
9
+ num_heads=[3, 6, 12, 24],
10
+ window_size=7,
11
+ mlp_ratio=4.,
12
+ qkv_bias=True,
13
+ qk_scale=None,
14
+ drop_rate=0.,
15
+ attn_drop_rate=0.,
16
+ drop_path_rate=0.2,
17
+ ape=False,
18
+ patch_norm=True,
19
+ out_indices=(0, 1, 2, 3),
20
+ use_checkpoint=False),
21
+ neck=dict(
22
+ type='FPN',
23
+ in_channels=[96, 192, 384, 768],
24
+ out_channels=256,
25
+ num_outs=5),
26
+ rpn_head=dict(
27
+ type='RPNHead',
28
+ in_channels=256,
29
+ feat_channels=256,
30
+ anchor_generator=dict(
31
+ type='AnchorGenerator',
32
+ scales=[8],
33
+ ratios=[0.5, 1.0, 2.0],
34
+ strides=[4, 8, 16, 32, 64]),
35
+ bbox_coder=dict(
36
+ type='DeltaXYWHBBoxCoder',
37
+ target_means=[.0, .0, .0, .0],
38
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
39
+ loss_cls=dict(
40
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
41
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
42
+ roi_head=dict(
43
+ type='StandardRoIHead',
44
+ bbox_roi_extractor=dict(
45
+ type='SingleRoIExtractor',
46
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
47
+ out_channels=256,
48
+ featmap_strides=[4, 8, 16, 32]),
49
+ bbox_head=dict(
50
+ type='Shared2FCBBoxHead',
51
+ in_channels=256,
52
+ fc_out_channels=1024,
53
+ roi_feat_size=7,
54
+ num_classes=80,
55
+ bbox_coder=dict(
56
+ type='DeltaXYWHBBoxCoder',
57
+ target_means=[0., 0., 0., 0.],
58
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
59
+ reg_class_agnostic=False,
60
+ loss_cls=dict(
61
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
62
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
63
+ mask_roi_extractor=dict(
64
+ type='SingleRoIExtractor',
65
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
66
+ out_channels=256,
67
+ featmap_strides=[4, 8, 16, 32]),
68
+ mask_head=dict(
69
+ type='FCNMaskHead',
70
+ num_convs=4,
71
+ in_channels=256,
72
+ conv_out_channels=256,
73
+ num_classes=80,
74
+ loss_mask=dict(
75
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
76
+ # model training and testing settings
77
+ train_cfg=dict(
78
+ rpn=dict(
79
+ assigner=dict(
80
+ type='MaxIoUAssigner',
81
+ pos_iou_thr=0.7,
82
+ neg_iou_thr=0.3,
83
+ min_pos_iou=0.3,
84
+ match_low_quality=True,
85
+ ignore_iof_thr=-1),
86
+ sampler=dict(
87
+ type='RandomSampler',
88
+ num=256,
89
+ pos_fraction=0.5,
90
+ neg_pos_ub=-1,
91
+ add_gt_as_proposals=False),
92
+ allowed_border=-1,
93
+ pos_weight=-1,
94
+ debug=False),
95
+ rpn_proposal=dict(
96
+ nms_pre=2000,
97
+ max_per_img=1000,
98
+ nms=dict(type='nms', iou_threshold=0.7),
99
+ min_bbox_size=0),
100
+ rcnn=dict(
101
+ assigner=dict(
102
+ type='MaxIoUAssigner',
103
+ pos_iou_thr=0.5,
104
+ neg_iou_thr=0.5,
105
+ min_pos_iou=0.5,
106
+ match_low_quality=True,
107
+ ignore_iof_thr=-1),
108
+ sampler=dict(
109
+ type='RandomSampler',
110
+ num=512,
111
+ pos_fraction=0.25,
112
+ neg_pos_ub=-1,
113
+ add_gt_as_proposals=True),
114
+ mask_size=28,
115
+ pos_weight=-1,
116
+ debug=False)),
117
+ test_cfg=dict(
118
+ rpn=dict(
119
+ nms_pre=1000,
120
+ max_per_img=1000,
121
+ nms=dict(type='nms', iou_threshold=0.7),
122
+ min_bbox_size=0),
123
+ rcnn=dict(
124
+ score_thr=0.05,
125
+ nms=dict(type='nms', iou_threshold=0.5),
126
+ max_per_img=100,
127
+ mask_thr_binary=0.5)))
configs/_base_/models/occ_mask_rcnn_swin_fpn.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ model = dict(
3
+ type='MaskRCNN',
4
+ pretrained=None,
5
+ backbone=dict(
6
+ type='SwinTransformer',
7
+ embed_dim=96,
8
+ depths=[2, 2, 6, 2],
9
+ num_heads=[3, 6, 12, 24],
10
+ window_size=7,
11
+ mlp_ratio=4.,
12
+ qkv_bias=True,
13
+ qk_scale=None,
14
+ drop_rate=0.,
15
+ attn_drop_rate=0.,
16
+ drop_path_rate=0.2,
17
+ ape=False,
18
+ patch_norm=True,
19
+ out_indices=(0, 1, 2, 3),
20
+ use_checkpoint=False),
21
+ neck=dict(
22
+ type='FPN',
23
+ in_channels=[96, 192, 384, 768],
24
+ out_channels=256,
25
+ num_outs=5),
26
+ rpn_head=dict(
27
+ type='RPNHead',
28
+ in_channels=256,
29
+ feat_channels=256,
30
+ anchor_generator=dict(
31
+ type='AnchorGenerator',
32
+ scales=[8],
33
+ ratios=[0.5, 1.0, 2.0],
34
+ strides=[4, 8, 16, 32, 64]),
35
+ bbox_coder=dict(
36
+ type='DeltaXYWHBBoxCoder',
37
+ target_means=[.0, .0, .0, .0],
38
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
39
+ loss_cls=dict(
40
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
41
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
42
+ roi_head=dict(
43
+ type='StandardRoIHead',
44
+ bbox_roi_extractor=dict(
45
+ type='SingleRoIExtractor',
46
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
47
+ out_channels=256,
48
+ featmap_strides=[4, 8, 16, 32]),
49
+ bbox_head=dict(
50
+ type='Shared2FCBBoxHead',
51
+ in_channels=256,
52
+ fc_out_channels=1024,
53
+ roi_feat_size=7,
54
+ num_classes=80,
55
+ bbox_coder=dict(
56
+ type='DeltaXYWHBBoxCoder',
57
+ target_means=[0., 0., 0., 0.],
58
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
59
+ reg_class_agnostic=False,
60
+ loss_cls=dict(
61
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
62
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
63
+ mask_roi_extractor=dict(
64
+ type='SingleRoIExtractor',
65
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
66
+ out_channels=256,
67
+ featmap_strides=[4, 8, 16, 32]),
68
+ mask_head=dict(
69
+ type='FCNOccMaskHead',
70
+ num_convs=4,
71
+ in_channels=256,
72
+ conv_out_channels=256,
73
+ num_classes=80,
74
+ loss_mask=dict(
75
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
76
+ # model training and testing settings
77
+ train_cfg=dict(
78
+ rpn=dict(
79
+ assigner=dict(
80
+ type='MaxIoUAssigner',
81
+ pos_iou_thr=0.7,
82
+ neg_iou_thr=0.3,
83
+ min_pos_iou=0.3,
84
+ match_low_quality=True,
85
+ ignore_iof_thr=-1),
86
+ sampler=dict(
87
+ type='RandomSampler',
88
+ num=256,
89
+ pos_fraction=0.5,
90
+ neg_pos_ub=-1,
91
+ add_gt_as_proposals=False),
92
+ allowed_border=-1,
93
+ pos_weight=-1,
94
+ debug=False),
95
+ rpn_proposal=dict(
96
+ nms_pre=2000,
97
+ max_per_img=1000,
98
+ nms=dict(type='nms', iou_threshold=0.7),
99
+ min_bbox_size=0),
100
+ rcnn=dict(
101
+ assigner=dict(
102
+ type='MaxIoUAssigner',
103
+ pos_iou_thr=0.5,
104
+ neg_iou_thr=0.5,
105
+ min_pos_iou=0.5,
106
+ match_low_quality=True,
107
+ ignore_iof_thr=-1),
108
+ sampler=dict(
109
+ type='RandomSampler',
110
+ num=512,
111
+ pos_fraction=0.25,
112
+ neg_pos_ub=-1,
113
+ add_gt_as_proposals=True),
114
+ mask_size=28,
115
+ pos_weight=-1,
116
+ debug=False)),
117
+ test_cfg=dict(
118
+ rpn=dict(
119
+ nms_pre=1000,
120
+ max_per_img=1000,
121
+ nms=dict(type='nms', iou_threshold=0.7),
122
+ min_bbox_size=0),
123
+ rcnn=dict(
124
+ score_thr=0.05,
125
+ nms=dict(type='nms', iou_threshold=0.5),
126
+ max_per_img=100,
127
+ mask_thr_binary=0.5)))
configs/_base_/schedules/schedule_1x.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # optimizer
2
+ optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
3
+ optimizer_config = dict(grad_clip=None)
4
+ # learning policy
5
+ lr_config = dict(
6
+ policy='step',
7
+ warmup='linear',
8
+ warmup_iters=500,
9
+ warmup_ratio=0.001,
10
+ step=[8, 11])
11
+ runner = dict(type='EpochBasedRunner', max_epochs=12)
configs/walt/walt_people.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/models/occ_mask_rcnn_swin_fpn.py',
3
+ '../_base_/datasets/walt_people.py',
4
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5
+ ]
6
+
7
+ model = dict(
8
+ backbone=dict(
9
+ embed_dim=96,
10
+ depths=[2, 2, 6, 2],
11
+ num_heads=[3, 6, 12, 24],
12
+ window_size=7,
13
+ ape=False,
14
+ drop_path_rate=0.1,
15
+ patch_norm=True,
16
+ use_checkpoint=False
17
+ ),
18
+ neck=dict(in_channels=[96, 192, 384, 768]))
19
+
20
+ img_norm_cfg = dict(
21
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
22
+
23
+ # augmentation strategy originates from DETR / Sparse RCNN
24
+ train_pipeline = [
25
+ dict(type='LoadImageFromFile'),
26
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
27
+ dict(type='RandomFlip', flip_ratio=0.5),
28
+ dict(type='AutoAugment',
29
+ policies=[
30
+ [
31
+ dict(type='Resize',
32
+ img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
33
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
34
+ (736, 1333), (768, 1333), (800, 1333)],
35
+ multiscale_mode='value',
36
+ keep_ratio=True)
37
+ ],
38
+ [
39
+ dict(type='Resize',
40
+ img_scale=[(400, 1333), (500, 1333), (600, 1333)],
41
+ multiscale_mode='value',
42
+ keep_ratio=True),
43
+ dict(type='RandomCrop',
44
+ crop_type='absolute_range',
45
+ crop_size=(384, 600),
46
+ allow_negative_crop=True),
47
+ dict(type='Resize',
48
+ img_scale=[(480, 1333), (512, 1333), (544, 1333),
49
+ (576, 1333), (608, 1333), (640, 1333),
50
+ (672, 1333), (704, 1333), (736, 1333),
51
+ (768, 1333), (800, 1333)],
52
+ multiscale_mode='value',
53
+ override=True,
54
+ keep_ratio=True)
55
+ ]
56
+ ]),
57
+ dict(type='Normalize', **img_norm_cfg),
58
+ dict(type='Pad', size_divisor=32),
59
+ dict(type='DefaultFormatBundle'),
60
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
61
+ ]
62
+ data = dict(train=dict(pipeline=train_pipeline))
63
+
64
+ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
65
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
66
+ 'relative_position_bias_table': dict(decay_mult=0.),
67
+ 'norm': dict(decay_mult=0.)}))
68
+ lr_config = dict(step=[8, 11])
69
+ runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
70
+
71
+ # do not use mmdet version fp16
72
+ fp16 = None
73
+ optimizer_config = dict(
74
+ type="DistOptimizerHook",
75
+ update_interval=1,
76
+ grad_clip=None,
77
+ coalesce=True,
78
+ bucket_size_mb=-1,
79
+ use_fp16=True,
80
+ )
configs/walt/walt_vehicle.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/models/occ_mask_rcnn_swin_fpn.py',
3
+ '../_base_/datasets/walt_vehicle.py',
4
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5
+ ]
6
+
7
+ model = dict(
8
+ backbone=dict(
9
+ embed_dim=96,
10
+ depths=[2, 2, 6, 2],
11
+ num_heads=[3, 6, 12, 24],
12
+ window_size=7,
13
+ ape=False,
14
+ drop_path_rate=0.1,
15
+ patch_norm=True,
16
+ use_checkpoint=False
17
+ ),
18
+ neck=dict(in_channels=[96, 192, 384, 768]))
19
+
20
+ img_norm_cfg = dict(
21
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
22
+
23
+ # augmentation strategy originates from DETR / Sparse RCNN
24
+ train_pipeline = [
25
+ dict(type='LoadImageFromFile'),
26
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
27
+ dict(type='RandomFlip', flip_ratio=0.5),
28
+ dict(type='AutoAugment',
29
+ policies=[
30
+ [
31
+ dict(type='Resize',
32
+ img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
33
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
34
+ (736, 1333), (768, 1333), (800, 1333)],
35
+ multiscale_mode='value',
36
+ keep_ratio=True)
37
+ ],
38
+ [
39
+ dict(type='Resize',
40
+ img_scale=[(400, 1333), (500, 1333), (600, 1333)],
41
+ multiscale_mode='value',
42
+ keep_ratio=True),
43
+ dict(type='RandomCrop',
44
+ crop_type='absolute_range',
45
+ crop_size=(384, 600),
46
+ allow_negative_crop=True),
47
+ dict(type='Resize',
48
+ img_scale=[(480, 1333), (512, 1333), (544, 1333),
49
+ (576, 1333), (608, 1333), (640, 1333),
50
+ (672, 1333), (704, 1333), (736, 1333),
51
+ (768, 1333), (800, 1333)],
52
+ multiscale_mode='value',
53
+ override=True,
54
+ keep_ratio=True)
55
+ ]
56
+ ]),
57
+ dict(type='Normalize', **img_norm_cfg),
58
+ dict(type='Pad', size_divisor=32),
59
+ dict(type='DefaultFormatBundle'),
60
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
61
+ ]
62
+ data = dict(train=dict(pipeline=train_pipeline))
63
+
64
+ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
65
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
66
+ 'relative_position_bias_table': dict(decay_mult=0.),
67
+ 'norm': dict(decay_mult=0.)}))
68
+ lr_config = dict(step=[8, 11])
69
+ runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
70
+
71
+ # do not use mmdet version fp16
72
+ fp16 = None
73
+ optimizer_config = dict(
74
+ type="DistOptimizerHook",
75
+ update_interval=1,
76
+ grad_clip=None,
77
+ coalesce=True,
78
+ bucket_size_mb=-1,
79
+ use_fp16=True,
80
+ )
docker/Dockerfile ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG PYTORCH="1.9.0"
2
+ ARG CUDA="11.1"
3
+ ARG CUDNN="8"
4
+
5
+ FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6
+
7
+ ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
8
+ ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
9
+ ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
10
+ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
11
+ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
12
+ RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
13
+ && apt-get clean \
14
+ && rm -rf /var/lib/apt/lists/*
15
+
16
+ # Install MMCV
17
+ #RUN pip install mmcv-full==1.3.8 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
18
+ # -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
19
+ RUN pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
20
+ # Install MMDetection
21
+ RUN conda clean --all
22
+ RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection
23
+ WORKDIR /mmdetection
24
+ ENV FORCE_CUDA="1"
25
+ RUN cd /mmdetection && git checkout 7bd39044f35aec4b90dd797b965777541a8678ff
26
+ RUN pip install -r requirements/build.txt
27
+ RUN pip install --no-cache-dir -e .
28
+ RUN apt-get update
29
+ RUN apt-get install -y vim
30
+ RUN pip uninstall -y pycocotools
31
+ RUN pip install mmpycocotools timm scikit-image imagesize
32
+
33
+
34
+ # make sure we don't overwrite some existing directory called "apex"
35
+ WORKDIR /tmp/unique_for_apex
36
+ # uninstall Apex if present, twice to make absolutely sure :)
37
+ RUN pip uninstall -y apex || :
38
+ RUN pip uninstall -y apex || :
39
+ # SHA is something the user can touch to force recreation of this Docker layer,
40
+ # and therefore force cloning of the latest version of Apex
41
+ RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git
42
+ WORKDIR /tmp/unique_for_apex/apex
43
+ RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
44
+ RUN pip install seaborn sklearn imantics gradio
45
+ WORKDIR /code
46
+ ENTRYPOINT ["python", "app.py"]
47
+
48
+ #RUN git clone https://github.com/NVIDIA/apex
49
+ #RUN cd apex
50
+ #RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
51
+ #RUN pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
52
+
github_vis/cwalt.gif ADDED
github_vis/vis_cars.gif ADDED
github_vis/vis_people.gif ADDED
mmcv_custom/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .checkpoint import load_checkpoint
4
+
5
+ __all__ = ['load_checkpoint']
mmcv_custom/checkpoint.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import io
3
+ import os
4
+ import os.path as osp
5
+ import pkgutil
6
+ import time
7
+ import warnings
8
+ from collections import OrderedDict
9
+ from importlib import import_module
10
+ from tempfile import TemporaryDirectory
11
+
12
+ import torch
13
+ import torchvision
14
+ from torch.optim import Optimizer
15
+ from torch.utils import model_zoo
16
+ from torch.nn import functional as F
17
+
18
+ import mmcv
19
+ from mmcv.fileio import FileClient
20
+ from mmcv.fileio import load as load_file
21
+ from mmcv.parallel import is_module_wrapper
22
+ from mmcv.utils import mkdir_or_exist
23
+ from mmcv.runner import get_dist_info
24
+
25
+ ENV_MMCV_HOME = 'MMCV_HOME'
26
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
27
+ DEFAULT_CACHE_DIR = '~/.cache'
28
+
29
+
30
+ def _get_mmcv_home():
31
+ mmcv_home = os.path.expanduser(
32
+ os.getenv(
33
+ ENV_MMCV_HOME,
34
+ os.path.join(
35
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
36
+
37
+ mkdir_or_exist(mmcv_home)
38
+ return mmcv_home
39
+
40
+
41
+ def load_state_dict(module, state_dict, strict=False, logger=None):
42
+ """Load state_dict to a module.
43
+
44
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
45
+ Default value for ``strict`` is set to ``False`` and the message for
46
+ param mismatch will be shown even if strict is False.
47
+
48
+ Args:
49
+ module (Module): Module that receives the state_dict.
50
+ state_dict (OrderedDict): Weights.
51
+ strict (bool): whether to strictly enforce that the keys
52
+ in :attr:`state_dict` match the keys returned by this module's
53
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
54
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
55
+ message. If not specified, print function will be used.
56
+ """
57
+ unexpected_keys = []
58
+ all_missing_keys = []
59
+ err_msg = []
60
+
61
+ metadata = getattr(state_dict, '_metadata', None)
62
+ state_dict = state_dict.copy()
63
+ if metadata is not None:
64
+ state_dict._metadata = metadata
65
+
66
+ # use _load_from_state_dict to enable checkpoint version control
67
+ def load(module, prefix=''):
68
+ # recursively check parallel module in case that the model has a
69
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
70
+ if is_module_wrapper(module):
71
+ module = module.module
72
+ local_metadata = {} if metadata is None else metadata.get(
73
+ prefix[:-1], {})
74
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
75
+ all_missing_keys, unexpected_keys,
76
+ err_msg)
77
+ for name, child in module._modules.items():
78
+ if child is not None:
79
+ load(child, prefix + name + '.')
80
+
81
+ load(module)
82
+ load = None # break load->load reference cycle
83
+
84
+ # ignore "num_batches_tracked" of BN layers
85
+ missing_keys = [
86
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
87
+ ]
88
+
89
+ if unexpected_keys:
90
+ err_msg.append('unexpected key in source '
91
+ f'state_dict: {", ".join(unexpected_keys)}\n')
92
+ if missing_keys:
93
+ err_msg.append(
94
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
95
+
96
+ rank, _ = get_dist_info()
97
+ if len(err_msg) > 0 and rank == 0:
98
+ err_msg.insert(
99
+ 0, 'The model and loaded state dict do not match exactly\n')
100
+ err_msg = '\n'.join(err_msg)
101
+ if strict:
102
+ raise RuntimeError(err_msg)
103
+ elif logger is not None:
104
+ logger.warning(err_msg)
105
+ else:
106
+ print(err_msg)
107
+
108
+
109
+ def load_url_dist(url, model_dir=None):
110
+ """In distributed setting, this function only download checkpoint at local
111
+ rank 0."""
112
+ rank, world_size = get_dist_info()
113
+ rank = int(os.environ.get('LOCAL_RANK', rank))
114
+ if rank == 0:
115
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
116
+ if world_size > 1:
117
+ torch.distributed.barrier()
118
+ if rank > 0:
119
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
120
+ return checkpoint
121
+
122
+
123
+ def load_pavimodel_dist(model_path, map_location=None):
124
+ """In distributed setting, this function only download checkpoint at local
125
+ rank 0."""
126
+ try:
127
+ from pavi import modelcloud
128
+ except ImportError:
129
+ raise ImportError(
130
+ 'Please install pavi to load checkpoint from modelcloud.')
131
+ rank, world_size = get_dist_info()
132
+ rank = int(os.environ.get('LOCAL_RANK', rank))
133
+ if rank == 0:
134
+ model = modelcloud.get(model_path)
135
+ with TemporaryDirectory() as tmp_dir:
136
+ downloaded_file = osp.join(tmp_dir, model.name)
137
+ model.download(downloaded_file)
138
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
139
+ if world_size > 1:
140
+ torch.distributed.barrier()
141
+ if rank > 0:
142
+ model = modelcloud.get(model_path)
143
+ with TemporaryDirectory() as tmp_dir:
144
+ downloaded_file = osp.join(tmp_dir, model.name)
145
+ model.download(downloaded_file)
146
+ checkpoint = torch.load(
147
+ downloaded_file, map_location=map_location)
148
+ return checkpoint
149
+
150
+
151
+ def load_fileclient_dist(filename, backend, map_location):
152
+ """In distributed setting, this function only download checkpoint at local
153
+ rank 0."""
154
+ rank, world_size = get_dist_info()
155
+ rank = int(os.environ.get('LOCAL_RANK', rank))
156
+ allowed_backends = ['ceph']
157
+ if backend not in allowed_backends:
158
+ raise ValueError(f'Load from Backend {backend} is not supported.')
159
+ if rank == 0:
160
+ fileclient = FileClient(backend=backend)
161
+ buffer = io.BytesIO(fileclient.get(filename))
162
+ checkpoint = torch.load(buffer, map_location=map_location)
163
+ if world_size > 1:
164
+ torch.distributed.barrier()
165
+ if rank > 0:
166
+ fileclient = FileClient(backend=backend)
167
+ buffer = io.BytesIO(fileclient.get(filename))
168
+ checkpoint = torch.load(buffer, map_location=map_location)
169
+ return checkpoint
170
+
171
+
172
+ def get_torchvision_models():
173
+ model_urls = dict()
174
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
175
+ if ispkg:
176
+ continue
177
+ _zoo = import_module(f'torchvision.models.{name}')
178
+ if hasattr(_zoo, 'model_urls'):
179
+ _urls = getattr(_zoo, 'model_urls')
180
+ model_urls.update(_urls)
181
+ return model_urls
182
+
183
+
184
+ def get_external_models():
185
+ mmcv_home = _get_mmcv_home()
186
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
187
+ default_urls = load_file(default_json_path)
188
+ assert isinstance(default_urls, dict)
189
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
190
+ if osp.exists(external_json_path):
191
+ external_urls = load_file(external_json_path)
192
+ assert isinstance(external_urls, dict)
193
+ default_urls.update(external_urls)
194
+
195
+ return default_urls
196
+
197
+
198
+ def get_mmcls_models():
199
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
200
+ mmcls_urls = load_file(mmcls_json_path)
201
+
202
+ return mmcls_urls
203
+
204
+
205
+ def get_deprecated_model_names():
206
+ deprecate_json_path = osp.join(mmcv.__path__[0],
207
+ 'model_zoo/deprecated.json')
208
+ deprecate_urls = load_file(deprecate_json_path)
209
+ assert isinstance(deprecate_urls, dict)
210
+
211
+ return deprecate_urls
212
+
213
+
214
+ def _process_mmcls_checkpoint(checkpoint):
215
+ state_dict = checkpoint['state_dict']
216
+ new_state_dict = OrderedDict()
217
+ for k, v in state_dict.items():
218
+ if k.startswith('backbone.'):
219
+ new_state_dict[k[9:]] = v
220
+ new_checkpoint = dict(state_dict=new_state_dict)
221
+
222
+ return new_checkpoint
223
+
224
+
225
+ def _load_checkpoint(filename, map_location=None):
226
+ """Load checkpoint from somewhere (modelzoo, file, url).
227
+
228
+ Args:
229
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
230
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
231
+ details.
232
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
233
+
234
+ Returns:
235
+ dict | OrderedDict: The loaded checkpoint. It can be either an
236
+ OrderedDict storing model weights or a dict containing other
237
+ information, which depends on the checkpoint.
238
+ """
239
+ if filename.startswith('modelzoo://'):
240
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
241
+ 'use "torchvision://" instead')
242
+ model_urls = get_torchvision_models()
243
+ model_name = filename[11:]
244
+ checkpoint = load_url_dist(model_urls[model_name])
245
+ elif filename.startswith('torchvision://'):
246
+ model_urls = get_torchvision_models()
247
+ model_name = filename[14:]
248
+ checkpoint = load_url_dist(model_urls[model_name])
249
+ elif filename.startswith('open-mmlab://'):
250
+ model_urls = get_external_models()
251
+ model_name = filename[13:]
252
+ deprecated_urls = get_deprecated_model_names()
253
+ if model_name in deprecated_urls:
254
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
255
+ f'of open-mmlab://{deprecated_urls[model_name]}')
256
+ model_name = deprecated_urls[model_name]
257
+ model_url = model_urls[model_name]
258
+ # check if is url
259
+ if model_url.startswith(('http://', 'https://')):
260
+ checkpoint = load_url_dist(model_url)
261
+ else:
262
+ filename = osp.join(_get_mmcv_home(), model_url)
263
+ if not osp.isfile(filename):
264
+ raise IOError(f'{filename} is not a checkpoint file')
265
+ checkpoint = torch.load(filename, map_location=map_location)
266
+ elif filename.startswith('mmcls://'):
267
+ model_urls = get_mmcls_models()
268
+ model_name = filename[8:]
269
+ checkpoint = load_url_dist(model_urls[model_name])
270
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
271
+ elif filename.startswith(('http://', 'https://')):
272
+ checkpoint = load_url_dist(filename)
273
+ elif filename.startswith('pavi://'):
274
+ model_path = filename[7:]
275
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
276
+ elif filename.startswith('s3://'):
277
+ checkpoint = load_fileclient_dist(
278
+ filename, backend='ceph', map_location=map_location)
279
+ else:
280
+ if not osp.isfile(filename):
281
+ raise IOError(f'{filename} is not a checkpoint file')
282
+ checkpoint = torch.load(filename, map_location=map_location)
283
+ return checkpoint
284
+
285
+
286
+ def load_checkpoint(model,
287
+ filename,
288
+ map_location='cpu',
289
+ strict=False,
290
+ logger=None):
291
+ """Load checkpoint from a file or URI.
292
+
293
+ Args:
294
+ model (Module): Module to load checkpoint.
295
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
296
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
297
+ details.
298
+ map_location (str): Same as :func:`torch.load`.
299
+ strict (bool): Whether to allow different params for the model and
300
+ checkpoint.
301
+ logger (:mod:`logging.Logger` or None): The logger for error message.
302
+
303
+ Returns:
304
+ dict or OrderedDict: The loaded checkpoint.
305
+ """
306
+ checkpoint = _load_checkpoint(filename, map_location)
307
+ # OrderedDict is a subclass of dict
308
+ if not isinstance(checkpoint, dict):
309
+ raise RuntimeError(
310
+ f'No state_dict found in checkpoint file {filename}')
311
+ # get state_dict from checkpoint
312
+ if 'state_dict' in checkpoint:
313
+ state_dict = checkpoint['state_dict']
314
+ elif 'model' in checkpoint:
315
+ state_dict = checkpoint['model']
316
+ else:
317
+ state_dict = checkpoint
318
+ # strip prefix of state_dict
319
+ if list(state_dict.keys())[0].startswith('module.'):
320
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
321
+
322
+ # for MoBY, load model of online branch
323
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
324
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
325
+
326
+ # reshape absolute position embedding
327
+ if state_dict.get('absolute_pos_embed') is not None:
328
+ absolute_pos_embed = state_dict['absolute_pos_embed']
329
+ N1, L, C1 = absolute_pos_embed.size()
330
+ N2, C2, H, W = model.absolute_pos_embed.size()
331
+ if N1 != N2 or C1 != C2 or L != H*W:
332
+ logger.warning("Error in loading absolute_pos_embed, pass")
333
+ else:
334
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
335
+
336
+ # interpolate position bias table if needed
337
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
338
+ for table_key in relative_position_bias_table_keys:
339
+ table_pretrained = state_dict[table_key]
340
+ table_current = model.state_dict()[table_key]
341
+ L1, nH1 = table_pretrained.size()
342
+ L2, nH2 = table_current.size()
343
+ if nH1 != nH2:
344
+ logger.warning(f"Error in loading {table_key}, pass")
345
+ else:
346
+ if L1 != L2:
347
+ S1 = int(L1 ** 0.5)
348
+ S2 = int(L2 ** 0.5)
349
+ table_pretrained_resized = F.interpolate(
350
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
351
+ size=(S2, S2), mode='bicubic')
352
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
353
+
354
+ # load state_dict
355
+ load_state_dict(model, state_dict, strict, logger)
356
+ return checkpoint
357
+
358
+
359
+ def weights_to_cpu(state_dict):
360
+ """Copy a model state_dict to cpu.
361
+
362
+ Args:
363
+ state_dict (OrderedDict): Model weights on GPU.
364
+
365
+ Returns:
366
+ OrderedDict: Model weights on GPU.
367
+ """
368
+ state_dict_cpu = OrderedDict()
369
+ for key, val in state_dict.items():
370
+ state_dict_cpu[key] = val.cpu()
371
+ return state_dict_cpu
372
+
373
+
374
+ def _save_to_state_dict(module, destination, prefix, keep_vars):
375
+ """Saves module state to `destination` dictionary.
376
+
377
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
378
+
379
+ Args:
380
+ module (nn.Module): The module to generate state_dict.
381
+ destination (dict): A dict where state will be stored.
382
+ prefix (str): The prefix for parameters and buffers used in this
383
+ module.
384
+ """
385
+ for name, param in module._parameters.items():
386
+ if param is not None:
387
+ destination[prefix + name] = param if keep_vars else param.detach()
388
+ for name, buf in module._buffers.items():
389
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
390
+ if buf is not None:
391
+ destination[prefix + name] = buf if keep_vars else buf.detach()
392
+
393
+
394
+ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
395
+ """Returns a dictionary containing a whole state of the module.
396
+
397
+ Both parameters and persistent buffers (e.g. running averages) are
398
+ included. Keys are corresponding parameter and buffer names.
399
+
400
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
401
+ recursively check parallel module in case that the model has a complicated
402
+ structure, e.g., nn.Module(nn.Module(DDP)).
403
+
404
+ Args:
405
+ module (nn.Module): The module to generate state_dict.
406
+ destination (OrderedDict): Returned dict for the state of the
407
+ module.
408
+ prefix (str): Prefix of the key.
409
+ keep_vars (bool): Whether to keep the variable property of the
410
+ parameters. Default: False.
411
+
412
+ Returns:
413
+ dict: A dictionary containing a whole state of the module.
414
+ """
415
+ # recursively check parallel module in case that the model has a
416
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
417
+ if is_module_wrapper(module):
418
+ module = module.module
419
+
420
+ # below is the same as torch.nn.Module.state_dict()
421
+ if destination is None:
422
+ destination = OrderedDict()
423
+ destination._metadata = OrderedDict()
424
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
425
+ version=module._version)
426
+ _save_to_state_dict(module, destination, prefix, keep_vars)
427
+ for name, child in module._modules.items():
428
+ if child is not None:
429
+ get_state_dict(
430
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
431
+ for hook in module._state_dict_hooks.values():
432
+ hook_result = hook(module, destination, prefix, local_metadata)
433
+ if hook_result is not None:
434
+ destination = hook_result
435
+ return destination
436
+
437
+
438
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
439
+ """Save checkpoint to file.
440
+
441
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
442
+ ``optimizer``. By default ``meta`` will contain version and time info.
443
+
444
+ Args:
445
+ model (Module): Module whose params are to be saved.
446
+ filename (str): Checkpoint filename.
447
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
448
+ meta (dict, optional): Metadata to be saved in checkpoint.
449
+ """
450
+ if meta is None:
451
+ meta = {}
452
+ elif not isinstance(meta, dict):
453
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
454
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
455
+
456
+ if is_module_wrapper(model):
457
+ model = model.module
458
+
459
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
460
+ # save class name to the meta
461
+ meta.update(CLASSES=model.CLASSES)
462
+
463
+ checkpoint = {
464
+ 'meta': meta,
465
+ 'state_dict': weights_to_cpu(get_state_dict(model))
466
+ }
467
+ # save optimizer state dict in the checkpoint
468
+ if isinstance(optimizer, Optimizer):
469
+ checkpoint['optimizer'] = optimizer.state_dict()
470
+ elif isinstance(optimizer, dict):
471
+ checkpoint['optimizer'] = {}
472
+ for name, optim in optimizer.items():
473
+ checkpoint['optimizer'][name] = optim.state_dict()
474
+
475
+ if filename.startswith('pavi://'):
476
+ try:
477
+ from pavi import modelcloud
478
+ from pavi.exception import NodeNotFoundError
479
+ except ImportError:
480
+ raise ImportError(
481
+ 'Please install pavi to load checkpoint from modelcloud.')
482
+ model_path = filename[7:]
483
+ root = modelcloud.Folder()
484
+ model_dir, model_name = osp.split(model_path)
485
+ try:
486
+ model = modelcloud.get(model_dir)
487
+ except NodeNotFoundError:
488
+ model = root.create_training_model(model_dir)
489
+ with TemporaryDirectory() as tmp_dir:
490
+ checkpoint_file = osp.join(tmp_dir, model_name)
491
+ with open(checkpoint_file, 'wb') as f:
492
+ torch.save(checkpoint, f)
493
+ f.flush()
494
+ model.create_file(checkpoint_file, name=model_name)
495
+ else:
496
+ mmcv.mkdir_or_exist(osp.dirname(filename))
497
+ # immediately flush buffer
498
+ with open(filename, 'wb') as f:
499
+ torch.save(checkpoint, f)
500
+ f.flush()
mmcv_custom/runner/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ from .checkpoint import save_checkpoint
3
+ from .epoch_based_runner import EpochBasedRunnerAmp
4
+
5
+
6
+ __all__ = [
7
+ 'EpochBasedRunnerAmp', 'save_checkpoint'
8
+ ]
mmcv_custom/runner/checkpoint.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import os.path as osp
3
+ import time
4
+ from tempfile import TemporaryDirectory
5
+
6
+ import torch
7
+ from torch.optim import Optimizer
8
+
9
+ import mmcv
10
+ from mmcv.parallel import is_module_wrapper
11
+ from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
12
+
13
+ try:
14
+ import apex
15
+ except:
16
+ print('apex is not installed')
17
+
18
+
19
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
20
+ """Save checkpoint to file.
21
+
22
+ The checkpoint will have 4 fields: ``meta``, ``state_dict`` and
23
+ ``optimizer``, ``amp``. By default ``meta`` will contain version
24
+ and time info.
25
+
26
+ Args:
27
+ model (Module): Module whose params are to be saved.
28
+ filename (str): Checkpoint filename.
29
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
30
+ meta (dict, optional): Metadata to be saved in checkpoint.
31
+ """
32
+ if meta is None:
33
+ meta = {}
34
+ elif not isinstance(meta, dict):
35
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
36
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
37
+
38
+ if is_module_wrapper(model):
39
+ model = model.module
40
+
41
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
42
+ # save class name to the meta
43
+ meta.update(CLASSES=model.CLASSES)
44
+
45
+ checkpoint = {
46
+ 'meta': meta,
47
+ 'state_dict': weights_to_cpu(get_state_dict(model))
48
+ }
49
+ # save optimizer state dict in the checkpoint
50
+ if isinstance(optimizer, Optimizer):
51
+ checkpoint['optimizer'] = optimizer.state_dict()
52
+ elif isinstance(optimizer, dict):
53
+ checkpoint['optimizer'] = {}
54
+ for name, optim in optimizer.items():
55
+ checkpoint['optimizer'][name] = optim.state_dict()
56
+
57
+ # save amp state dict in the checkpoint
58
+ checkpoint['amp'] = apex.amp.state_dict()
59
+
60
+ if filename.startswith('pavi://'):
61
+ try:
62
+ from pavi import modelcloud
63
+ from pavi.exception import NodeNotFoundError
64
+ except ImportError:
65
+ raise ImportError(
66
+ 'Please install pavi to load checkpoint from modelcloud.')
67
+ model_path = filename[7:]
68
+ root = modelcloud.Folder()
69
+ model_dir, model_name = osp.split(model_path)
70
+ try:
71
+ model = modelcloud.get(model_dir)
72
+ except NodeNotFoundError:
73
+ model = root.create_training_model(model_dir)
74
+ with TemporaryDirectory() as tmp_dir:
75
+ checkpoint_file = osp.join(tmp_dir, model_name)
76
+ with open(checkpoint_file, 'wb') as f:
77
+ torch.save(checkpoint, f)
78
+ f.flush()
79
+ model.create_file(checkpoint_file, name=model_name)
80
+ else:
81
+ mmcv.mkdir_or_exist(osp.dirname(filename))
82
+ # immediately flush buffer
83
+ with open(filename, 'wb') as f:
84
+ torch.save(checkpoint, f)
85
+ f.flush()
mmcv_custom/runner/epoch_based_runner.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import os.path as osp
3
+ import platform
4
+ import shutil
5
+
6
+ import torch
7
+ from torch.optim import Optimizer
8
+
9
+ import mmcv
10
+ from mmcv.runner import RUNNERS, EpochBasedRunner
11
+ from .checkpoint import save_checkpoint
12
+
13
+ try:
14
+ import apex
15
+ except:
16
+ print('apex is not installed')
17
+
18
+
19
+ @RUNNERS.register_module()
20
+ class EpochBasedRunnerAmp(EpochBasedRunner):
21
+ """Epoch-based Runner with AMP support.
22
+
23
+ This runner train models epoch by epoch.
24
+ """
25
+
26
+ def save_checkpoint(self,
27
+ out_dir,
28
+ filename_tmpl='epoch_{}.pth',
29
+ save_optimizer=True,
30
+ meta=None,
31
+ create_symlink=True):
32
+ """Save the checkpoint.
33
+
34
+ Args:
35
+ out_dir (str): The directory that checkpoints are saved.
36
+ filename_tmpl (str, optional): The checkpoint filename template,
37
+ which contains a placeholder for the epoch number.
38
+ Defaults to 'epoch_{}.pth'.
39
+ save_optimizer (bool, optional): Whether to save the optimizer to
40
+ the checkpoint. Defaults to True.
41
+ meta (dict, optional): The meta information to be saved in the
42
+ checkpoint. Defaults to None.
43
+ create_symlink (bool, optional): Whether to create a symlink
44
+ "latest.pth" to point to the latest checkpoint.
45
+ Defaults to True.
46
+ """
47
+ if meta is None:
48
+ meta = dict(epoch=self.epoch + 1, iter=self.iter)
49
+ elif isinstance(meta, dict):
50
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
51
+ else:
52
+ raise TypeError(
53
+ f'meta should be a dict or None, but got {type(meta)}')
54
+ if self.meta is not None:
55
+ meta.update(self.meta)
56
+
57
+ filename = filename_tmpl.format(self.epoch + 1)
58
+ filepath = osp.join(out_dir, filename)
59
+ optimizer = self.optimizer if save_optimizer else None
60
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
61
+ # in some environments, `os.symlink` is not supported, you may need to
62
+ # set `create_symlink` to False
63
+ if create_symlink:
64
+ dst_file = osp.join(out_dir, 'latest.pth')
65
+ if platform.system() != 'Windows':
66
+ mmcv.symlink(filename, dst_file)
67
+ else:
68
+ shutil.copy(filepath, dst_file)
69
+
70
+ def resume(self,
71
+ checkpoint,
72
+ resume_optimizer=True,
73
+ map_location='default'):
74
+ if map_location == 'default':
75
+ if torch.cuda.is_available():
76
+ device_id = torch.cuda.current_device()
77
+ checkpoint = self.load_checkpoint(
78
+ checkpoint,
79
+ map_location=lambda storage, loc: storage.cuda(device_id))
80
+ else:
81
+ checkpoint = self.load_checkpoint(checkpoint)
82
+ else:
83
+ checkpoint = self.load_checkpoint(
84
+ checkpoint, map_location=map_location)
85
+
86
+ self._epoch = checkpoint['meta']['epoch']
87
+ self._iter = checkpoint['meta']['iter']
88
+ if 'optimizer' in checkpoint and resume_optimizer:
89
+ if isinstance(self.optimizer, Optimizer):
90
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
91
+ elif isinstance(self.optimizer, dict):
92
+ for k in self.optimizer.keys():
93
+ self.optimizer[k].load_state_dict(
94
+ checkpoint['optimizer'][k])
95
+ else:
96
+ raise TypeError(
97
+ 'Optimizer should be dict or torch.optim.Optimizer '
98
+ f'but got {type(self.optimizer)}')
99
+
100
+ if 'amp' in checkpoint:
101
+ apex.amp.load_state_dict(checkpoint['amp'])
102
+ self.logger.info('load amp state dict')
103
+
104
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
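
Because the class registers itself with RUNNERS, a training config can select it by name. A minimal, hypothetical config fragment (the epoch count is illustrative; as train.py later in this commit shows, apex is only initialized when the optimizer hook is a DistOptimizerHook with use_fp16=True):

    runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
    optimizer_config = dict(type='DistOptimizerHook', use_fp16=True)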
mmdet/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ import mmcv
2
+
3
+ from .version import __version__, short_version
4
+
5
+
6
+ def digit_version(version_str):
7
+ digit_version = []
8
+ for x in version_str.split('.'):
9
+ if x.isdigit():
10
+ digit_version.append(int(x))
11
+ elif x.find('rc') != -1:
12
+ patch_version = x.split('rc')
13
+ digit_version.append(int(patch_version[0]) - 1)
14
+ digit_version.append(int(patch_version[1]))
15
+ return digit_version
16
+
17
+
18
+ mmcv_minimum_version = '1.2.4'
19
+ mmcv_maximum_version = '1.4.0'
20
+ mmcv_version = digit_version(mmcv.__version__)
21
+
22
+
23
+ assert (mmcv_version >= digit_version(mmcv_minimum_version)
24
+ and mmcv_version <= digit_version(mmcv_maximum_version)), \
25
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
26
+ f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
27
+
28
+ __all__ = ['__version__', 'short_version']
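
The 'rc' handling in digit_version can be checked with a small worked example (assuming the package and a compatible mmcv are installed so the import-time assert passes):

    from mmdet import digit_version

    # 'rc' pre-releases are mapped below their final release, so range checks work:
    assert digit_version('1.3.0') == [1, 3, 0]
    assert digit_version('1.3.0rc1') == [1, 3, -1, 1]
    assert digit_version('1.2.4') <= digit_version('1.3.0rc1') <= digit_version('1.4.0')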
mmdet/apis/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .inference import (async_inference_detector, inference_detector,
2
+ init_detector, show_result_pyplot)
3
+ from .test import multi_gpu_test, single_gpu_test
4
+ from .train import get_root_logger, set_random_seed, train_detector
5
+
6
+ __all__ = [
7
+ 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
8
+ 'async_inference_detector', 'inference_detector', 'show_result_pyplot',
9
+ 'multi_gpu_test', 'single_gpu_test'
10
+ ]
mmdet/apis/inference.py ADDED
@@ -0,0 +1,217 @@
1
+ import warnings
2
+
3
+ import mmcv
4
+ import numpy as np
5
+ import torch
6
+ from mmcv.ops import RoIPool
7
+ from mmcv.parallel import collate, scatter
8
+ from mmcv.runner import load_checkpoint
9
+
10
+ from mmdet.core import get_classes
11
+ from mmdet.datasets import replace_ImageToTensor
12
+ from mmdet.datasets.pipelines import Compose
13
+ from mmdet.models import build_detector
14
+
15
+
16
+ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
17
+ """Initialize a detector from config file.
18
+
19
+ Args:
20
+ config (str or :obj:`mmcv.Config`): Config file path or the config
21
+ object.
22
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
23
+ will not load any weights.
24
+ cfg_options (dict): Options to override some settings in the used
25
+ config.
26
+
27
+ Returns:
28
+ nn.Module: The constructed detector.
29
+ """
30
+ if isinstance(config, str):
31
+ config = mmcv.Config.fromfile(config)
32
+ elif not isinstance(config, mmcv.Config):
33
+ raise TypeError('config must be a filename or Config object, '
34
+ f'but got {type(config)}')
35
+ if cfg_options is not None:
36
+ config.merge_from_dict(cfg_options)
37
+ config.model.pretrained = None
38
+ config.model.train_cfg = None
39
+ model = build_detector(config.model, test_cfg=config.get('test_cfg'))
40
+ if checkpoint is not None:
41
+ map_loc = 'cpu' if device == 'cpu' else None
42
+ checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
43
+ if 'CLASSES' in checkpoint.get('meta', {}):
44
+ model.CLASSES = checkpoint['meta']['CLASSES']
45
+ else:
46
+ warnings.simplefilter('once')
47
+ warnings.warn('Class names are not saved in the checkpoint\'s '
48
+ 'meta data, use COCO classes by default.')
49
+ model.CLASSES = get_classes('coco')
50
+ model.cfg = config # save the config in the model for convenience
51
+ model.to(device)
52
+ model.eval()
53
+ return model
54
+
55
+
56
+ class LoadImage(object):
57
+ """Deprecated.
58
+
59
+ A simple pipeline to load image.
60
+ """
61
+
62
+ def __call__(self, results):
63
+ """Call function to load images into results.
64
+
65
+ Args:
66
+ results (dict): A result dict contains the file name
67
+ of the image to be read.
68
+ Returns:
69
+ dict: ``results`` will be returned containing loaded image.
70
+ """
71
+ warnings.simplefilter('once')
72
+ warnings.warn('`LoadImage` is deprecated and will be removed in '
73
+ 'future releases. You may use `LoadImageFromWebcam` '
74
+ 'from `mmdet.datasets.pipelines.` instead.')
75
+ if isinstance(results['img'], str):
76
+ results['filename'] = results['img']
77
+ results['ori_filename'] = results['img']
78
+ else:
79
+ results['filename'] = None
80
+ results['ori_filename'] = None
81
+ img = mmcv.imread(results['img'])
82
+ results['img'] = img
83
+ results['img_fields'] = ['img']
84
+ results['img_shape'] = img.shape
85
+ results['ori_shape'] = img.shape
86
+ return results
87
+
88
+
89
+ def inference_detector(model, imgs):
90
+ """Inference image(s) with the detector.
91
+
92
+ Args:
93
+ model (nn.Module): The loaded detector.
94
+ imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
95
+ Either image files or loaded images.
96
+
97
+ Returns:
98
+ If imgs is a list or tuple, a list of detection results of the
99
+ same length will be returned; otherwise the detection result is returned directly.
100
+ """
101
+
102
+ if isinstance(imgs, (list, tuple)):
103
+ is_batch = True
104
+ else:
105
+ imgs = [imgs]
106
+ is_batch = False
107
+
108
+ cfg = model.cfg
109
+ device = next(model.parameters()).device # model device
110
+
111
+ if isinstance(imgs[0], np.ndarray):
112
+ cfg = cfg.copy()
113
+ # set loading pipeline type
114
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
115
+
116
+ cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
117
+ test_pipeline = Compose(cfg.data.test.pipeline)
118
+
119
+ datas = []
120
+ for img in imgs:
121
+ # prepare data
122
+ if isinstance(img, np.ndarray):
123
+ # directly add img
124
+ data = dict(img=img)
125
+ else:
126
+ # add information into dict
127
+ data = dict(img_info=dict(filename=img), img_prefix=None)
128
+ # build the data pipeline
129
+ data = test_pipeline(data)
130
+ datas.append(data)
131
+
132
+ data = collate(datas, samples_per_gpu=len(imgs))
133
+ # just get the actual data from DataContainer
134
+ data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
135
+ data['img'] = [img.data[0] for img in data['img']]
136
+ if next(model.parameters()).is_cuda:
137
+ # scatter to specified GPU
138
+ data = scatter(data, [device])[0]
139
+ else:
140
+ for m in model.modules():
141
+ assert not isinstance(
142
+ m, RoIPool
143
+ ), 'CPU inference with RoIPool is not supported currently.'
144
+
145
+ # forward the model
146
+ with torch.no_grad():
147
+ results = model(return_loss=False, rescale=True, **data)
148
+
149
+ if not is_batch:
150
+ return results[0]
151
+ else:
152
+ return results
153
+
154
+
155
+ async def async_inference_detector(model, img):
156
+ """Async inference image(s) with the detector.
157
+
158
+ Args:
159
+ model (nn.Module): The loaded detector.
160
+ img (str | ndarray): Either image files or loaded images.
161
+
162
+ Returns:
163
+ Awaitable detection results.
164
+ """
165
+ cfg = model.cfg
166
+ device = next(model.parameters()).device # model device
167
+ # prepare data
168
+ if isinstance(img, np.ndarray):
169
+ # directly add img
170
+ data = dict(img=img)
171
+ cfg = cfg.copy()
172
+ # set loading pipeline type
173
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
174
+ else:
175
+ # add information into dict
176
+ data = dict(img_info=dict(filename=img), img_prefix=None)
177
+ # build the data pipeline
178
+ test_pipeline = Compose(cfg.data.test.pipeline)
179
+ data = test_pipeline(data)
180
+ data = scatter(collate([data], samples_per_gpu=1), [device])[0]
181
+
182
+ # We don't restore `torch.is_grad_enabled()` value during concurrent
183
+ # inference since execution can overlap
184
+ torch.set_grad_enabled(False)
185
+ result = await model.aforward_test(rescale=True, **data)
186
+ return result
187
+
188
+
189
+ def show_result_pyplot(model,
190
+ img,
191
+ result,
192
+ score_thr=0.3,
193
+ title='result',
194
+ wait_time=0):
195
+ """Visualize the detection results on the image.
196
+
197
+ Args:
198
+ model (nn.Module): The loaded detector.
199
+ img (str or np.ndarray): Image filename or loaded image.
200
+ result (tuple[list] or list): The detection result, can be either
201
+ (bbox, segm) or just bbox.
202
+ score_thr (float): The threshold to visualize the bboxes and masks.
203
+ title (str): Title of the pyplot figure.
204
+ wait_time (float): Value of waitKey param.
205
+ Default: 0.
206
+ """
207
+ if hasattr(model, 'module'):
208
+ model = model.module
209
+ model.show_result(
210
+ img,
211
+ result,
212
+ score_thr=score_thr,
213
+ show=True,
214
+ wait_time=wait_time,
215
+ win_name=title,
216
+ bbox_color=(72, 101, 241),
217
+ text_color=(72, 101, 241))
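
Together these functions cover the typical single-image workflow; a hedged sketch with placeholder config, checkpoint and image paths:

    from mmdet.apis import inference_detector, init_detector, show_result_pyplot

    config_file = 'configs/some_config.py'      # placeholder
    checkpoint_file = 'work_dirs/latest.pth'    # placeholder
    model = init_detector(config_file, checkpoint_file, device='cuda:0')

    result = inference_detector(model, 'demo.jpg')  # single image -> single result
    show_result_pyplot(model, 'demo.jpg', result, score_thr=0.3)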
mmdet/apis/test.py ADDED
@@ -0,0 +1,189 @@
1
+ import os.path as osp
2
+ import pickle
3
+ import shutil
4
+ import tempfile
5
+ import time
6
+
7
+ import mmcv
8
+ import torch
9
+ import torch.distributed as dist
10
+ from mmcv.image import tensor2imgs
11
+ from mmcv.runner import get_dist_info
12
+
13
+ from mmdet.core import encode_mask_results
14
+
15
+
16
+ def single_gpu_test(model,
17
+ data_loader,
18
+ show=False,
19
+ out_dir=None,
20
+ show_score_thr=0.3):
21
+ model.eval()
22
+ results = []
23
+ dataset = data_loader.dataset
24
+ prog_bar = mmcv.ProgressBar(len(dataset))
25
+ for i, data in enumerate(data_loader):
26
+ with torch.no_grad():
27
+ result = model(return_loss=False, rescale=True, **data)
28
+
29
+ batch_size = len(result)
30
+ if show or out_dir:
31
+ if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
32
+ img_tensor = data['img'][0]
33
+ else:
34
+ img_tensor = data['img'][0].data[0]
35
+ img_metas = data['img_metas'][0].data[0]
36
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
37
+ assert len(imgs) == len(img_metas)
38
+
39
+ for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
40
+ h, w, _ = img_meta['img_shape']
41
+ img_show = img[:h, :w, :]
42
+
43
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
44
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
45
+
46
+ if out_dir:
47
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
48
+ else:
49
+ out_file = None
50
+ model.module.show_result(
51
+ img_show,
52
+ result[i],
53
+ show=show,
54
+ out_file=out_file,
55
+ score_thr=show_score_thr)
56
+
57
+ # encode mask results
58
+ if isinstance(result[0], tuple):
59
+ result = [(bbox_results, encode_mask_results(mask_results))
60
+ for bbox_results, mask_results in result]
61
+ results.extend(result)
62
+
63
+ for _ in range(batch_size):
64
+ prog_bar.update()
65
+ return results
66
+
67
+
68
+ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
69
+ """Test model with multiple gpus.
70
+
71
+ This method tests model with multiple gpus and collects the results
72
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
73
+ it encodes results to gpu tensors and uses gpu communication for results
74
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
75
+ and collects them by the rank 0 worker.
76
+
77
+ Args:
78
+ model (nn.Module): Model to be tested.
79
+ data_loader (nn.Dataloader): Pytorch data loader.
80
+ tmpdir (str): Path of directory to save the temporary results from
81
+ different gpus under cpu mode.
82
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
83
+
84
+ Returns:
85
+ list: The prediction results.
86
+ """
87
+ model.eval()
88
+ results = []
89
+ dataset = data_loader.dataset
90
+ rank, world_size = get_dist_info()
91
+ if rank == 0:
92
+ prog_bar = mmcv.ProgressBar(len(dataset))
93
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
94
+ for i, data in enumerate(data_loader):
95
+ with torch.no_grad():
96
+ result = model(return_loss=False, rescale=True, **data)
97
+ # encode mask results
98
+ if isinstance(result[0], tuple):
99
+ result = [(bbox_results, encode_mask_results(mask_results))
100
+ for bbox_results, mask_results in result]
101
+ results.extend(result)
102
+
103
+ if rank == 0:
104
+ batch_size = len(result)
105
+ for _ in range(batch_size * world_size):
106
+ prog_bar.update()
107
+
108
+ # collect results from all ranks
109
+ if gpu_collect:
110
+ results = collect_results_gpu(results, len(dataset))
111
+ else:
112
+ results = collect_results_cpu(results, len(dataset), tmpdir)
113
+ return results
114
+
115
+
116
+ def collect_results_cpu(result_part, size, tmpdir=None):
117
+ rank, world_size = get_dist_info()
118
+ # create a tmp dir if it is not specified
119
+ if tmpdir is None:
120
+ MAX_LEN = 512
121
+ # 32 is whitespace
122
+ dir_tensor = torch.full((MAX_LEN, ),
123
+ 32,
124
+ dtype=torch.uint8,
125
+ device='cuda')
126
+ if rank == 0:
127
+ mmcv.mkdir_or_exist('.dist_test')
128
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
129
+ tmpdir = torch.tensor(
130
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
131
+ dir_tensor[:len(tmpdir)] = tmpdir
132
+ dist.broadcast(dir_tensor, 0)
133
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
134
+ else:
135
+ mmcv.mkdir_or_exist(tmpdir)
136
+ # dump the part result to the dir
137
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
138
+ dist.barrier()
139
+ # collect all parts
140
+ if rank != 0:
141
+ return None
142
+ else:
143
+ # load results of all parts from tmp dir
144
+ part_list = []
145
+ for i in range(world_size):
146
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
147
+ part_list.append(mmcv.load(part_file))
148
+ # sort the results
149
+ ordered_results = []
150
+ for res in zip(*part_list):
151
+ ordered_results.extend(list(res))
152
+ # the dataloader may pad some samples
153
+ ordered_results = ordered_results[:size]
154
+ # remove tmp dir
155
+ shutil.rmtree(tmpdir)
156
+ return ordered_results
157
+
158
+
159
+ def collect_results_gpu(result_part, size):
160
+ rank, world_size = get_dist_info()
161
+ # dump result part to tensor with pickle
162
+ part_tensor = torch.tensor(
163
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
164
+ # gather all result part tensor shape
165
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
166
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
167
+ dist.all_gather(shape_list, shape_tensor)
168
+ # padding result part tensor to max length
169
+ shape_max = torch.tensor(shape_list).max()
170
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
171
+ part_send[:shape_tensor[0]] = part_tensor
172
+ part_recv_list = [
173
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
174
+ ]
175
+ # gather all result part
176
+ dist.all_gather(part_recv_list, part_send)
177
+
178
+ if rank == 0:
179
+ part_list = []
180
+ for recv, shape in zip(part_recv_list, shape_list):
181
+ part_list.append(
182
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
183
+ # sort the results
184
+ ordered_results = []
185
+ for res in zip(*part_list):
186
+ ordered_results.extend(list(res))
187
+ # the dataloader may pad some samples
188
+ ordered_results = ordered_results[:size]
189
+ return ordered_results
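
The GPU collection path works by padding variable-length pickled payloads to a common size before dist.all_gather exchanges them. A standalone, CPU-only illustration of that padding and unpadding step (no process group involved):

    import pickle

    import torch

    part = {'rank': 0, 'boxes': [[0, 0, 10, 10]]}
    part_tensor = torch.tensor(bytearray(pickle.dumps(part)), dtype=torch.uint8)
    shape_max = part_tensor.numel() + 16  # pretend another rank sent a larger part
    part_send = torch.zeros(shape_max, dtype=torch.uint8)
    part_send[:part_tensor.numel()] = part_tensor
    # The receiver slices off the padding using the gathered true length:
    recovered = pickle.loads(part_send[:part_tensor.numel()].numpy().tobytes())
    assert recovered == part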
mmdet/apis/train.py ADDED
@@ -0,0 +1,185 @@
1
+ import random
2
+ import warnings
3
+
4
+ import numpy as np
5
+ import torch
6
+ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
7
+ from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
8
+ Fp16OptimizerHook, OptimizerHook, build_optimizer,
9
+ build_runner)
10
+ from mmcv.utils import build_from_cfg
11
+
12
+ from mmdet.core import DistEvalHook, EvalHook
13
+ from mmdet.datasets import (build_dataloader, build_dataset,
14
+ replace_ImageToTensor)
15
+ from mmdet.utils import get_root_logger
16
+ from mmcv_custom.runner import EpochBasedRunnerAmp
17
+ try:
18
+ import apex
19
+ except ImportError:
20
+ print('apex is not installed')
21
+
22
+
23
+ def set_random_seed(seed, deterministic=False):
24
+ """Set random seed.
25
+
26
+ Args:
27
+ seed (int): Seed to be used.
28
+ deterministic (bool): Whether to set the deterministic option for
29
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
30
+ to True and `torch.backends.cudnn.benchmark` to False.
31
+ Default: False.
32
+ """
33
+ random.seed(seed)
34
+ np.random.seed(seed)
35
+ torch.manual_seed(seed)
36
+ torch.cuda.manual_seed_all(seed)
37
+ if deterministic:
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.backends.cudnn.benchmark = False
40
+
41
+
42
+ def train_detector(model,
43
+ dataset,
44
+ cfg,
45
+ distributed=False,
46
+ validate=False,
47
+ timestamp=None,
48
+ meta=None):
49
+ logger = get_root_logger(cfg.log_level)
50
+
51
+ # prepare data loaders
52
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
53
+ if 'imgs_per_gpu' in cfg.data:
54
+ logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
55
+ 'Please use "samples_per_gpu" instead')
56
+ if 'samples_per_gpu' in cfg.data:
57
+ logger.warning(
58
+ f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
59
+ f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
60
+ f'={cfg.data.imgs_per_gpu} is used in this experiment')
61
+ else:
62
+ logger.warning(
63
+ 'Automatically set "samples_per_gpu"="imgs_per_gpu"='
64
+ f'{cfg.data.imgs_per_gpu} in this experiment')
65
+ cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
66
+
67
+ data_loaders = [
68
+ build_dataloader(
69
+ ds,
70
+ cfg.data.samples_per_gpu,
71
+ cfg.data.workers_per_gpu,
72
+ # cfg.gpus will be ignored if distributed
73
+ len(cfg.gpu_ids),
74
+ dist=distributed,
75
+ seed=cfg.seed) for ds in dataset
76
+ ]
77
+
78
+ # build optimizer
79
+ optimizer = build_optimizer(model, cfg.optimizer)
80
+
81
+ # use apex fp16 optimizer
82
+ if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
83
+ if cfg.optimizer_config.get("use_fp16", False):
84
+ model, optimizer = apex.amp.initialize(
85
+ model.cuda(), optimizer, opt_level="O1")
86
+ for m in model.modules():
87
+ if hasattr(m, "fp16_enabled"):
88
+ m.fp16_enabled = True
89
+
90
+ # put model on gpus
91
+ if distributed:
92
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
93
+ # Sets the `find_unused_parameters` parameter in
94
+ # torch.nn.parallel.DistributedDataParallel
95
+ model = MMDistributedDataParallel(
96
+ model.cuda(),
97
+ device_ids=[torch.cuda.current_device()],
98
+ broadcast_buffers=False,
99
+ find_unused_parameters=find_unused_parameters)
100
+ else:
101
+ model = MMDataParallel(
102
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
103
+
104
+ if 'runner' not in cfg:
105
+ cfg.runner = {
106
+ 'type': 'EpochBasedRunner',
107
+ 'max_epochs': cfg.total_epochs
108
+ }
109
+ warnings.warn(
110
+ 'config is now expected to have a `runner` section, '
111
+ 'please set `runner` in your config.', UserWarning)
112
+ else:
113
+ if 'total_epochs' in cfg:
114
+ assert cfg.total_epochs == cfg.runner.max_epochs
115
+
116
+ # build runner
117
+ runner = build_runner(
118
+ cfg.runner,
119
+ default_args=dict(
120
+ model=model,
121
+ optimizer=optimizer,
122
+ work_dir=cfg.work_dir,
123
+ logger=logger,
124
+ meta=meta))
125
+
126
+ # an ugly workaround to make .log and .log.json filenames the same
127
+ runner.timestamp = timestamp
128
+
129
+ # fp16 setting
130
+ fp16_cfg = cfg.get('fp16', None)
131
+ if fp16_cfg is not None:
132
+ optimizer_config = Fp16OptimizerHook(
133
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
134
+ elif distributed and 'type' not in cfg.optimizer_config:
135
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
136
+ else:
137
+ optimizer_config = cfg.optimizer_config
138
+
139
+ # register hooks
140
+ runner.register_training_hooks(cfg.lr_config, optimizer_config,
141
+ cfg.checkpoint_config, cfg.log_config,
142
+ cfg.get('momentum_config', None))
143
+ if distributed:
144
+ if isinstance(runner, EpochBasedRunner):
145
+ runner.register_hook(DistSamplerSeedHook())
146
+
147
+ # register eval hooks
148
+ if validate:
149
+ # Support batch_size > 1 in validation
150
+ val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
151
+ if val_samples_per_gpu > 1:
152
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
153
+ cfg.data.val.pipeline = replace_ImageToTensor(
154
+ cfg.data.val.pipeline)
155
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
156
+ val_dataloader = build_dataloader(
157
+ val_dataset,
158
+ samples_per_gpu=val_samples_per_gpu,
159
+ workers_per_gpu=cfg.data.workers_per_gpu,
160
+ dist=distributed,
161
+ shuffle=False)
162
+ eval_cfg = cfg.get('evaluation', {})
163
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
164
+ eval_hook = DistEvalHook if distributed else EvalHook
165
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
166
+
167
+ # user-defined hooks
168
+ if cfg.get('custom_hooks', None):
169
+ custom_hooks = cfg.custom_hooks
170
+ assert isinstance(custom_hooks, list), \
171
+ f'custom_hooks expect list type, but got {type(custom_hooks)}'
172
+ for hook_cfg in cfg.custom_hooks:
173
+ assert isinstance(hook_cfg, dict), \
174
+ 'Each item in custom_hooks expects dict type, but got ' \
175
+ f'{type(hook_cfg)}'
176
+ hook_cfg = hook_cfg.copy()
177
+ priority = hook_cfg.pop('priority', 'NORMAL')
178
+ hook = build_from_cfg(hook_cfg, HOOKS)
179
+ runner.register_hook(hook, priority=priority)
180
+
181
+ if cfg.resume_from:
182
+ runner.resume(cfg.resume_from)
183
+ elif cfg.load_from:
184
+ runner.load_checkpoint(cfg.load_from)
185
+ runner.run(data_loaders, cfg.workflow)
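
A condensed sketch of how a tools/train.py-style script might drive this function; paths are placeholders, single-GPU non-distributed training is assumed, and the config is expected to define the usual fields (data, optimizer, runner, load_from, resume_from):

    from mmcv import Config

    from mmdet.apis import set_random_seed, train_detector
    from mmdet.datasets import build_dataset
    from mmdet.models import build_detector

    cfg = Config.fromfile('configs/some_config.py')  # placeholder config path
    cfg.work_dir = './work_dirs/demo'
    cfg.gpu_ids = [0]
    cfg.seed = 0
    set_random_seed(cfg.seed)

    datasets = [build_dataset(cfg.data.train)]
    model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'),
                           test_cfg=cfg.get('test_cfg'))
    train_detector(model, datasets, cfg, distributed=False, validate=True)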
mmdet/core/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .anchor import * # noqa: F401, F403
2
+ from .bbox import * # noqa: F401, F403
3
+ from .evaluation import * # noqa: F401, F403
4
+ from .export import * # noqa: F401, F403
5
+ from .mask import * # noqa: F401, F403
6
+ from .post_processing import * # noqa: F401, F403
7
+ from .utils import * # noqa: F401, F403
mmdet/core/anchor/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
2
+ YOLOAnchorGenerator)
3
+ from .builder import ANCHOR_GENERATORS, build_anchor_generator
4
+ from .point_generator import PointGenerator
5
+ from .utils import anchor_inside_flags, calc_region, images_to_levels
6
+
7
+ __all__ = [
8
+ 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
9
+ 'PointGenerator', 'images_to_levels', 'calc_region',
10
+ 'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
11
+ ]
mmdet/core/anchor/anchor_generator.py ADDED
@@ -0,0 +1,727 @@
1
+ import mmcv
2
+ import numpy as np
3
+ import torch
4
+ from torch.nn.modules.utils import _pair
5
+
6
+ from .builder import ANCHOR_GENERATORS
7
+
8
+
9
+ @ANCHOR_GENERATORS.register_module()
10
+ class AnchorGenerator(object):
11
+ """Standard anchor generator for 2D anchor-based detectors.
12
+
13
+ Args:
14
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
15
+ in multiple feature levels in order (w, h).
16
+ ratios (list[float]): The list of ratios between the height and width
17
+ of anchors in a single level.
18
+ scales (list[int] | None): Anchor scales for anchors in a single level.
19
+ It cannot be set at the same time if `octave_base_scale` and
20
+ `scales_per_octave` are set.
21
+ base_sizes (list[int] | None): The basic sizes
22
+ of anchors in multiple levels.
23
+ If None is given, strides will be used as base_sizes.
24
+ (If strides are non square, the shortest stride is taken.)
25
+ scale_major (bool): Whether to multiply scales first when generating
26
+ base anchors. If true, the anchors in the same row will have the
27
+ same scales. By default it is True in V2.0
28
+ octave_base_scale (int): The base scale of octave.
29
+ scales_per_octave (int): Number of scales for each octave.
30
+ `octave_base_scale` and `scales_per_octave` are usually used in
31
+ retinanet and the `scales` should be None when they are set.
32
+ centers (list[tuple[float, float]] | None): The centers of the anchor
33
+ relative to the feature grid center in multiple feature levels.
34
+ By default it is set to be None and not used. If a list of tuple of
35
+ float is given, they will be used to shift the centers of anchors.
36
+ center_offset (float): The offset of center in proportion to anchors'
37
+ width and height. By default it is 0 in V2.0.
38
+
39
+ Examples:
40
+ >>> from mmdet.core import AnchorGenerator
41
+ >>> self = AnchorGenerator([16], [1.], [1.], [9])
42
+ >>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
43
+ >>> print(all_anchors)
44
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
45
+ [11.5000, -4.5000, 20.5000, 4.5000],
46
+ [-4.5000, 11.5000, 4.5000, 20.5000],
47
+ [11.5000, 11.5000, 20.5000, 20.5000]])]
48
+ >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
49
+ >>> all_anchors = self.grid_anchors([(2, 2), (1, 1)], device='cpu')
50
+ >>> print(all_anchors)
51
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
52
+ [11.5000, -4.5000, 20.5000, 4.5000],
53
+ [-4.5000, 11.5000, 4.5000, 20.5000],
54
+ [11.5000, 11.5000, 20.5000, 20.5000]]), \
55
+ tensor([[-9., -9., 9., 9.]])]
56
+ """
57
+
58
+ def __init__(self,
59
+ strides,
60
+ ratios,
61
+ scales=None,
62
+ base_sizes=None,
63
+ scale_major=True,
64
+ octave_base_scale=None,
65
+ scales_per_octave=None,
66
+ centers=None,
67
+ center_offset=0.):
68
+ # check center and center_offset
69
+ if center_offset != 0:
70
+ assert centers is None, 'center cannot be set when center_offset' \
71
+ f'!=0, {centers} is given.'
72
+ if not (0 <= center_offset <= 1):
73
+ raise ValueError('center_offset should be in range [0, 1], '
74
+ f'{center_offset} is given.')
75
+ if centers is not None:
76
+ assert len(centers) == len(strides), \
77
+ 'The number of strides should be the same as centers, got ' \
78
+ f'{strides} and {centers}'
79
+
80
+ # calculate base sizes of anchors
81
+ self.strides = [_pair(stride) for stride in strides]
82
+ self.base_sizes = [min(stride) for stride in self.strides
83
+ ] if base_sizes is None else base_sizes
84
+ assert len(self.base_sizes) == len(self.strides), \
85
+ 'The number of strides should be the same as base sizes, got ' \
86
+ f'{self.strides} and {self.base_sizes}'
87
+
88
+ # calculate scales of anchors
89
+ assert ((octave_base_scale is not None
90
+ and scales_per_octave is not None) ^ (scales is not None)), \
91
+ 'scales and octave_base_scale with scales_per_octave cannot' \
92
+ ' be set at the same time'
93
+ if scales is not None:
94
+ self.scales = torch.Tensor(scales)
95
+ elif octave_base_scale is not None and scales_per_octave is not None:
96
+ octave_scales = np.array(
97
+ [2**(i / scales_per_octave) for i in range(scales_per_octave)])
98
+ scales = octave_scales * octave_base_scale
99
+ self.scales = torch.Tensor(scales)
100
+ else:
101
+ raise ValueError('Either scales or octave_base_scale with '
102
+ 'scales_per_octave should be set')
103
+
104
+ self.octave_base_scale = octave_base_scale
105
+ self.scales_per_octave = scales_per_octave
106
+ self.ratios = torch.Tensor(ratios)
107
+ self.scale_major = scale_major
108
+ self.centers = centers
109
+ self.center_offset = center_offset
110
+ self.base_anchors = self.gen_base_anchors()
111
+
112
+ @property
113
+ def num_base_anchors(self):
114
+ """list[int]: total number of base anchors in a feature grid"""
115
+ return [base_anchors.size(0) for base_anchors in self.base_anchors]
116
+
117
+ @property
118
+ def num_levels(self):
119
+ """int: number of feature levels that the generator will be applied"""
120
+ return len(self.strides)
121
+
122
+ def gen_base_anchors(self):
123
+ """Generate base anchors.
124
+
125
+ Returns:
126
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
127
+ feature levels.
128
+ """
129
+ multi_level_base_anchors = []
130
+ for i, base_size in enumerate(self.base_sizes):
131
+ center = None
132
+ if self.centers is not None:
133
+ center = self.centers[i]
134
+ multi_level_base_anchors.append(
135
+ self.gen_single_level_base_anchors(
136
+ base_size,
137
+ scales=self.scales,
138
+ ratios=self.ratios,
139
+ center=center))
140
+ return multi_level_base_anchors
141
+
142
+ def gen_single_level_base_anchors(self,
143
+ base_size,
144
+ scales,
145
+ ratios,
146
+ center=None):
147
+ """Generate base anchors of a single level.
148
+
149
+ Args:
150
+ base_size (int | float): Basic size of an anchor.
151
+ scales (torch.Tensor): Scales of the anchor.
152
+ ratios (torch.Tensor): The ratio between the height
153
+ and width of anchors in a single level.
154
+ center (tuple[float], optional): The center of the base anchor
155
+ related to a single feature grid. Defaults to None.
156
+
157
+ Returns:
158
+ torch.Tensor: Anchors in a single-level feature maps.
159
+ """
160
+ w = base_size
161
+ h = base_size
162
+ if center is None:
163
+ x_center = self.center_offset * w
164
+ y_center = self.center_offset * h
165
+ else:
166
+ x_center, y_center = center
167
+
168
+ h_ratios = torch.sqrt(ratios)
169
+ w_ratios = 1 / h_ratios
170
+ if self.scale_major:
171
+ ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
172
+ hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
173
+ else:
174
+ ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
175
+ hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
176
+
177
+ # use float anchor and the anchor's center is aligned with the
178
+ # pixel center
179
+ base_anchors = [
180
+ x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
181
+ y_center + 0.5 * hs
182
+ ]
183
+ base_anchors = torch.stack(base_anchors, dim=-1)
184
+
185
+ return base_anchors
186
+
187
+ def _meshgrid(self, x, y, row_major=True):
188
+ """Generate mesh grid of x and y.
189
+
190
+ Args:
191
+ x (torch.Tensor): Grids of x dimension.
192
+ y (torch.Tensor): Grids of y dimension.
193
+ row_major (bool, optional): Whether to return y grids first.
194
+ Defaults to True.
195
+
196
+ Returns:
197
+ tuple[torch.Tensor]: The mesh grids of x and y.
198
+ """
199
+ # use shape instead of len to keep tracing while exporting to onnx
200
+ xx = x.repeat(y.shape[0])
201
+ yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
202
+ if row_major:
203
+ return xx, yy
204
+ else:
205
+ return yy, xx
206
+
207
+ def grid_anchors(self, featmap_sizes, device='cuda'):
208
+ """Generate grid anchors in multiple feature levels.
209
+
210
+ Args:
211
+ featmap_sizes (list[tuple]): List of feature map sizes in
212
+ multiple feature levels.
213
+ device (str): Device where the anchors will be put on.
214
+
215
+ Return:
216
+ list[torch.Tensor]: Anchors in multiple feature levels. \
217
+ The sizes of each tensor should be [N, 4], where \
218
+ N = width * height * num_base_anchors, width and height \
219
+ are the sizes of the corresponding feature level, \
220
+ num_base_anchors is the number of anchors for that level.
221
+ """
222
+ assert self.num_levels == len(featmap_sizes)
223
+ multi_level_anchors = []
224
+ for i in range(self.num_levels):
225
+ anchors = self.single_level_grid_anchors(
226
+ self.base_anchors[i].to(device),
227
+ featmap_sizes[i],
228
+ self.strides[i],
229
+ device=device)
230
+ multi_level_anchors.append(anchors)
231
+ return multi_level_anchors
232
+
233
+ def single_level_grid_anchors(self,
234
+ base_anchors,
235
+ featmap_size,
236
+ stride=(16, 16),
237
+ device='cuda'):
238
+ """Generate grid anchors of a single level.
239
+
240
+ Note:
241
+ This function is usually called by method ``self.grid_anchors``.
242
+
243
+ Args:
244
+ base_anchors (torch.Tensor): The base anchors of a feature grid.
245
+ featmap_size (tuple[int]): Size of the feature maps.
246
+ stride (tuple[int], optional): Stride of the feature map in order
247
+ (w, h). Defaults to (16, 16).
248
+ device (str, optional): Device the tensor will be put on.
249
+ Defaults to 'cuda'.
250
+
251
+ Returns:
252
+ torch.Tensor: Anchors in the overall feature maps.
253
+ """
254
+ # keep as Tensor, so that we can covert to ONNX correctly
255
+ feat_h, feat_w = featmap_size
256
+ shift_x = torch.arange(0, feat_w, device=device) * stride[0]
257
+ shift_y = torch.arange(0, feat_h, device=device) * stride[1]
258
+
259
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
260
+ shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
261
+ shifts = shifts.type_as(base_anchors)
262
+ # first feat_w elements correspond to the first row of shifts
263
+ # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
264
+ # shifted anchors (K, A, 4), reshape to (K*A, 4)
265
+
266
+ all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
267
+ all_anchors = all_anchors.view(-1, 4)
268
+ # first A rows correspond to A anchors of (0, 0) in feature map,
269
+ # then (0, 1), (0, 2), ...
270
+ return all_anchors
271
+
272
+ def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
273
+ """Generate valid flags of anchors in multiple feature levels.
274
+
275
+ Args:
276
+ featmap_sizes (list(tuple)): List of feature map sizes in
277
+ multiple feature levels.
278
+ pad_shape (tuple): The padded shape of the image.
279
+ device (str): Device where the anchors will be put on.
280
+
281
+ Return:
282
+ list(torch.Tensor): Valid flags of anchors in multiple levels.
283
+ """
284
+ assert self.num_levels == len(featmap_sizes)
285
+ multi_level_flags = []
286
+ for i in range(self.num_levels):
287
+ anchor_stride = self.strides[i]
288
+ feat_h, feat_w = featmap_sizes[i]
289
+ h, w = pad_shape[:2]
290
+ valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
291
+ valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
292
+ flags = self.single_level_valid_flags((feat_h, feat_w),
293
+ (valid_feat_h, valid_feat_w),
294
+ self.num_base_anchors[i],
295
+ device=device)
296
+ multi_level_flags.append(flags)
297
+ return multi_level_flags
298
+
299
+ def single_level_valid_flags(self,
300
+ featmap_size,
301
+ valid_size,
302
+ num_base_anchors,
303
+ device='cuda'):
304
+ """Generate the valid flags of anchor in a single feature map.
305
+
306
+ Args:
307
+ featmap_size (tuple[int]): The size of feature maps.
308
+ valid_size (tuple[int]): The valid size of the feature maps.
309
+ num_base_anchors (int): The number of base anchors.
310
+ device (str, optional): Device where the flags will be put on.
311
+ Defaults to 'cuda'.
312
+
313
+ Returns:
314
+ torch.Tensor: The valid flags of each anchor in a single level \
315
+ feature map.
316
+ """
317
+ feat_h, feat_w = featmap_size
318
+ valid_h, valid_w = valid_size
319
+ assert valid_h <= feat_h and valid_w <= feat_w
320
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
321
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
322
+ valid_x[:valid_w] = 1
323
+ valid_y[:valid_h] = 1
324
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
325
+ valid = valid_xx & valid_yy
326
+ valid = valid[:, None].expand(valid.size(0),
327
+ num_base_anchors).contiguous().view(-1)
328
+ return valid
329
+
330
+ def __repr__(self):
331
+ """str: a string that describes the module"""
332
+ indent_str = ' '
333
+ repr_str = self.__class__.__name__ + '(\n'
334
+ repr_str += f'{indent_str}strides={self.strides},\n'
335
+ repr_str += f'{indent_str}ratios={self.ratios},\n'
336
+ repr_str += f'{indent_str}scales={self.scales},\n'
337
+ repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
338
+ repr_str += f'{indent_str}scale_major={self.scale_major},\n'
339
+ repr_str += f'{indent_str}octave_base_scale='
340
+ repr_str += f'{self.octave_base_scale},\n'
341
+ repr_str += f'{indent_str}scales_per_octave='
342
+ repr_str += f'{self.scales_per_octave},\n'
343
+ repr_str += f'{indent_str}num_levels={self.num_levels}\n'
344
+ repr_str += f'{indent_str}centers={self.centers},\n'
345
+ repr_str += f'{indent_str}center_offset={self.center_offset})'
346
+ return repr_str
347
+
348
+
349
+ @ANCHOR_GENERATORS.register_module()
350
+ class SSDAnchorGenerator(AnchorGenerator):
351
+ """Anchor generator for SSD.
352
+
353
+ Args:
354
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
355
+ in multiple feature levels.
356
+ ratios (list[float]): The list of ratios between the height and width
357
+ of anchors in a single level.
358
+ basesize_ratio_range (tuple(float)): Ratio range of anchors.
359
+ input_size (int): Size of the input image, 300 for SSD300,
360
+ 512 for SSD512.
361
+ scale_major (bool): Whether to multiply scales first when generating
362
+ base anchors. If true, the anchors in the same row will have the
363
+ same scales. It is always set to be False in SSD.
364
+ """
365
+
366
+ def __init__(self,
367
+ strides,
368
+ ratios,
369
+ basesize_ratio_range,
370
+ input_size=300,
371
+ scale_major=True):
372
+ assert len(strides) == len(ratios)
373
+ assert mmcv.is_tuple_of(basesize_ratio_range, float)
374
+
375
+ self.strides = [_pair(stride) for stride in strides]
376
+ self.input_size = input_size
377
+ self.centers = [(stride[0] / 2., stride[1] / 2.)
378
+ for stride in self.strides]
379
+ self.basesize_ratio_range = basesize_ratio_range
380
+
381
+ # calculate anchor ratios and sizes
382
+ min_ratio, max_ratio = basesize_ratio_range
383
+ min_ratio = int(min_ratio * 100)
384
+ max_ratio = int(max_ratio * 100)
385
+ step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
386
+ min_sizes = []
387
+ max_sizes = []
388
+ for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
389
+ min_sizes.append(int(self.input_size * ratio / 100))
390
+ max_sizes.append(int(self.input_size * (ratio + step) / 100))
391
+ if self.input_size == 300:
392
+ if basesize_ratio_range[0] == 0.15: # SSD300 COCO
393
+ min_sizes.insert(0, int(self.input_size * 7 / 100))
394
+ max_sizes.insert(0, int(self.input_size * 15 / 100))
395
+ elif basesize_ratio_range[0] == 0.2: # SSD300 VOC
396
+ min_sizes.insert(0, int(self.input_size * 10 / 100))
397
+ max_sizes.insert(0, int(self.input_size * 20 / 100))
398
+ else:
399
+ raise ValueError(
400
+ 'basesize_ratio_range[0] should be either 0.15'
401
+ 'or 0.2 when input_size is 300, got '
402
+ f'{basesize_ratio_range[0]}.')
403
+ elif self.input_size == 512:
404
+ if basesize_ratio_range[0] == 0.1: # SSD512 COCO
405
+ min_sizes.insert(0, int(self.input_size * 4 / 100))
406
+ max_sizes.insert(0, int(self.input_size * 10 / 100))
407
+ elif basesize_ratio_range[0] == 0.15: # SSD512 VOC
408
+ min_sizes.insert(0, int(self.input_size * 7 / 100))
409
+ max_sizes.insert(0, int(self.input_size * 15 / 100))
410
+ else:
411
+ raise ValueError('basesize_ratio_range[0] should be either 0.1'
412
+ 'or 0.15 when input_size is 512, got'
413
+ f' {basesize_ratio_range[0]}.')
414
+ else:
415
+ raise ValueError('Only support 300 or 512 in SSDAnchorGenerator'
416
+ f', got {self.input_size}.')
417
+
418
+ anchor_ratios = []
419
+ anchor_scales = []
420
+ for k in range(len(self.strides)):
421
+ scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
422
+ anchor_ratio = [1.]
423
+ for r in ratios[k]:
424
+ anchor_ratio += [1 / r, r] # 4 or 6 ratio
425
+ anchor_ratios.append(torch.Tensor(anchor_ratio))
426
+ anchor_scales.append(torch.Tensor(scales))
427
+
428
+ self.base_sizes = min_sizes
429
+ self.scales = anchor_scales
430
+ self.ratios = anchor_ratios
431
+ self.scale_major = scale_major
432
+ self.center_offset = 0
433
+ self.base_anchors = self.gen_base_anchors()
434
+
435
+ def gen_base_anchors(self):
436
+ """Generate base anchors.
437
+
438
+ Returns:
439
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
440
+ feature levels.
441
+ """
442
+ multi_level_base_anchors = []
443
+ for i, base_size in enumerate(self.base_sizes):
444
+ base_anchors = self.gen_single_level_base_anchors(
445
+ base_size,
446
+ scales=self.scales[i],
447
+ ratios=self.ratios[i],
448
+ center=self.centers[i])
449
+ indices = list(range(len(self.ratios[i])))
450
+ indices.insert(1, len(indices))
451
+ base_anchors = torch.index_select(base_anchors, 0,
452
+ torch.LongTensor(indices))
453
+ multi_level_base_anchors.append(base_anchors)
454
+ return multi_level_base_anchors
455
+
456
+ def __repr__(self):
457
+ """str: a string that describes the module"""
458
+ indent_str = ' '
459
+ repr_str = self.__class__.__name__ + '(\n'
460
+ repr_str += f'{indent_str}strides={self.strides},\n'
461
+ repr_str += f'{indent_str}scales={self.scales},\n'
462
+ repr_str += f'{indent_str}scale_major={self.scale_major},\n'
463
+ repr_str += f'{indent_str}input_size={self.input_size},\n'
464
+ repr_str += f'{indent_str}scales={self.scales},\n'
465
+ repr_str += f'{indent_str}ratios={self.ratios},\n'
466
+ repr_str += f'{indent_str}num_levels={self.num_levels},\n'
467
+ repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
468
+ repr_str += f'{indent_str}basesize_ratio_range='
469
+ repr_str += f'{self.basesize_ratio_range})'
470
+ return repr_str
471
+
472
+
473
+ @ANCHOR_GENERATORS.register_module()
474
+ class LegacyAnchorGenerator(AnchorGenerator):
475
+ """Legacy anchor generator used in MMDetection V1.x.
476
+
477
+ Note:
478
+ Difference to the V2.0 anchor generator:
479
+
480
+ 1. The center offset of V1.x anchors are set to be 0.5 rather than 0.
481
+ 2. The width/height are reduced by 1 when calculating the anchors' \
482
+ centers and corners to meet the V1.x coordinate system.
483
+ 3. The anchors' corners are quantized.
484
+
485
+ Args:
486
+ strides (list[int] | list[tuple[int]]): Strides of anchors
487
+ in multiple feature levels.
488
+ ratios (list[float]): The list of ratios between the height and width
489
+ of anchors in a single level.
490
+ scales (list[int] | None): Anchor scales for anchors in a single level.
491
+ It cannot be set at the same time if `octave_base_scale` and
492
+ `scales_per_octave` are set.
493
+ base_sizes (list[int]): The basic sizes of anchors in multiple levels.
494
+ If None is given, strides will be used to generate base_sizes.
495
+ scale_major (bool): Whether to multiply scales first when generating
496
+ base anchors. If true, the anchors in the same row will have the
497
+ same scales. By default it is True in V2.0
498
+ octave_base_scale (int): The base scale of octave.
499
+ scales_per_octave (int): Number of scales for each octave.
500
+ `octave_base_scale` and `scales_per_octave` are usually used in
501
+ retinanet and the `scales` should be None when they are set.
502
+ centers (list[tuple[float, float]] | None): The centers of the anchor
503
+ relative to the feature grid center in multiple feature levels.
504
+ By default it is set to be None and not used. If a list of float
505
+ is given, this list will be used to shift the centers of anchors.
506
+ center_offset (float): The offset of center in proportion to anchors'
507
+ width and height. By default it is 0 in V2.0 but it should be 0.5
508
+ in v1.x models.
509
+
510
+ Examples:
511
+ >>> from mmdet.core import LegacyAnchorGenerator
512
+ >>> self = LegacyAnchorGenerator(
513
+ >>> [16], [1.], [1.], [9], center_offset=0.5)
514
+ >>> all_anchors = self.grid_anchors(((2, 2),), device='cpu')
515
+ >>> print(all_anchors)
516
+ [tensor([[ 0., 0., 8., 8.],
517
+ [16., 0., 24., 8.],
518
+ [ 0., 16., 8., 24.],
519
+ [16., 16., 24., 24.]])]
520
+ """
521
+
522
+ def gen_single_level_base_anchors(self,
523
+ base_size,
524
+ scales,
525
+ ratios,
526
+ center=None):
527
+ """Generate base anchors of a single level.
528
+
529
+ Note:
530
+ The width/height of anchors are reduced by 1 when calculating \
531
+ the centers and corners to meet the V1.x coordinate system.
532
+
533
+ Args:
534
+ base_size (int | float): Basic size of an anchor.
535
+ scales (torch.Tensor): Scales of the anchor.
536
+ ratios (torch.Tensor): The ratio between the height
537
+ and width of anchors in a single level.
538
+ center (tuple[float], optional): The center of the base anchor
539
+ related to a single feature grid. Defaults to None.
540
+
541
+ Returns:
542
+ torch.Tensor: Anchors in a single-level feature map.
543
+ """
544
+ w = base_size
545
+ h = base_size
546
+ if center is None:
547
+ x_center = self.center_offset * (w - 1)
548
+ y_center = self.center_offset * (h - 1)
549
+ else:
550
+ x_center, y_center = center
551
+
552
+ h_ratios = torch.sqrt(ratios)
553
+ w_ratios = 1 / h_ratios
554
+ if self.scale_major:
555
+ ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
556
+ hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
557
+ else:
558
+ ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
559
+ hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
560
+
561
+ # use float anchor and the anchor's center is aligned with the
562
+ # pixel center
563
+ base_anchors = [
564
+ x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
565
+ x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
566
+ ]
567
+ base_anchors = torch.stack(base_anchors, dim=-1).round()
568
+
569
+ return base_anchors
570
+
571
+
572
+ @ANCHOR_GENERATORS.register_module()
573
+ class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
574
+ """Legacy anchor generator used in MMDetection V1.x.
575
+
576
+ The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator`
577
+ can be found in `LegacyAnchorGenerator`.
578
+ """
579
+
580
+ def __init__(self,
581
+ strides,
582
+ ratios,
583
+ basesize_ratio_range,
584
+ input_size=300,
585
+ scale_major=True):
586
+ super(LegacySSDAnchorGenerator,
587
+ self).__init__(strides, ratios, basesize_ratio_range, input_size,
588
+ scale_major)
589
+ self.centers = [((stride - 1) / 2., (stride - 1) / 2.)
590
+ for stride in strides]
591
+ self.base_anchors = self.gen_base_anchors()
592
+
593
+
594
+ @ANCHOR_GENERATORS.register_module()
595
+ class YOLOAnchorGenerator(AnchorGenerator):
596
+ """Anchor generator for YOLO.
597
+
598
+ Args:
599
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
600
+ in multiple feature levels.
601
+ base_sizes (list[list[tuple[int, int]]]): The basic sizes
602
+ of anchors in multiple levels.
603
+ """
604
+
605
+ def __init__(self, strides, base_sizes):
606
+ self.strides = [_pair(stride) for stride in strides]
607
+ self.centers = [(stride[0] / 2., stride[1] / 2.)
608
+ for stride in self.strides]
609
+ self.base_sizes = []
610
+ num_anchor_per_level = len(base_sizes[0])
611
+ for base_sizes_per_level in base_sizes:
612
+ assert num_anchor_per_level == len(base_sizes_per_level)
613
+ self.base_sizes.append(
614
+ [_pair(base_size) for base_size in base_sizes_per_level])
615
+ self.base_anchors = self.gen_base_anchors()
616
+
617
+ @property
618
+ def num_levels(self):
619
+ """int: number of feature levels that the generator will be applied"""
620
+ return len(self.base_sizes)
621
+
622
+ def gen_base_anchors(self):
623
+ """Generate base anchors.
624
+
625
+ Returns:
626
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
627
+ feature levels.
628
+ """
629
+ multi_level_base_anchors = []
630
+ for i, base_sizes_per_level in enumerate(self.base_sizes):
631
+ center = None
632
+ if self.centers is not None:
633
+ center = self.centers[i]
634
+ multi_level_base_anchors.append(
635
+ self.gen_single_level_base_anchors(base_sizes_per_level,
636
+ center))
637
+ return multi_level_base_anchors
638
+
639
+ def gen_single_level_base_anchors(self, base_sizes_per_level, center=None):
640
+ """Generate base anchors of a single level.
641
+
642
+ Args:
643
+ base_sizes_per_level (list[tuple[int, int]]): Basic sizes of
644
+ anchors.
645
+ center (tuple[float], optional): The center of the base anchor
646
+ related to a single feature grid. Defaults to None.
647
+
648
+ Returns:
649
+ torch.Tensor: Anchors in a single-level feature maps.
650
+ """
651
+ x_center, y_center = center
652
+ base_anchors = []
653
+ for base_size in base_sizes_per_level:
654
+ w, h = base_size
655
+
656
+ # use float anchor and the anchor's center is aligned with the
657
+ # pixel center
658
+ base_anchor = torch.Tensor([
659
+ x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w,
660
+ y_center + 0.5 * h
661
+ ])
662
+ base_anchors.append(base_anchor)
663
+ base_anchors = torch.stack(base_anchors, dim=0)
664
+
665
+ return base_anchors
666
+
667
+ def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'):
668
+ """Generate responsible anchor flags of grid cells in multiple scales.
669
+
670
+ Args:
671
+ featmap_sizes (list(tuple)): List of feature map sizes in multiple
672
+ feature levels.
673
+ gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
674
+ device (str): Device where the anchors will be put on.
675
+
676
+ Return:
677
+ list(torch.Tensor): responsible flags of anchors in multiple levels
678
+ """
679
+ assert self.num_levels == len(featmap_sizes)
680
+ multi_level_responsible_flags = []
681
+ for i in range(self.num_levels):
682
+ anchor_stride = self.strides[i]
683
+ flags = self.single_level_responsible_flags(
684
+ featmap_sizes[i],
685
+ gt_bboxes,
686
+ anchor_stride,
687
+ self.num_base_anchors[i],
688
+ device=device)
689
+ multi_level_responsible_flags.append(flags)
690
+ return multi_level_responsible_flags
691
+
692
+ def single_level_responsible_flags(self,
693
+ featmap_size,
694
+ gt_bboxes,
695
+ stride,
696
+ num_base_anchors,
697
+ device='cuda'):
698
+ """Generate the responsible flags of anchor in a single feature map.
699
+
700
+ Args:
701
+ featmap_size (tuple[int]): The size of feature maps.
702
+ gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
703
+ stride (tuple(int)): stride of current level
704
+ num_base_anchors (int): The number of base anchors.
705
+ device (str, optional): Device where the flags will be put on.
706
+ Defaults to 'cuda'.
707
+
708
+ Returns:
709
+ torch.Tensor: The valid flags of each anchor in a single level \
710
+ feature map.
711
+ """
712
+ feat_h, feat_w = featmap_size
713
+ gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
714
+ gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
715
+ gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
716
+ gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()
717
+
718
+ # row major indexing
719
+ gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x
720
+
721
+ responsible_grid = torch.zeros(
722
+ feat_h * feat_w, dtype=torch.uint8, device=device)
723
+ responsible_grid[gt_bboxes_grid_idx] = 1
724
+
725
+ responsible_grid = responsible_grid[:, None].expand(
726
+ responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
727
+ return responsible_grid
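
Beyond the docstring examples, grid_anchors and valid_flags are typically used together; a small CPU sketch with illustrative sizes:

    from mmdet.core import AnchorGenerator

    gen = AnchorGenerator(strides=[16], ratios=[0.5, 1.0, 2.0], scales=[8])
    anchors = gen.grid_anchors([(4, 4)], device='cpu')  # one level, 4x4 grid
    flags = gen.valid_flags([(4, 4)], pad_shape=(50, 50), device='cpu')
    print(anchors[0].shape, flags[0].shape)  # torch.Size([48, 4]) torch.Size([48])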
mmdet/core/anchor/builder.py ADDED
@@ -0,0 +1,7 @@
1
+ from mmcv.utils import Registry, build_from_cfg
2
+
3
+ ANCHOR_GENERATORS = Registry('Anchor generator')
4
+
5
+
6
+ def build_anchor_generator(cfg, default_args=None):
7
+ return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
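
With the registry in place, generators are normally built from config dicts; the parameters below are illustrative rather than taken from any config in this commit:

    from mmdet.core.anchor import build_anchor_generator

    anchor_cfg = dict(
        type='AnchorGenerator',
        strides=[8, 16, 32],
        ratios=[0.5, 1.0, 2.0],
        octave_base_scale=4,
        scales_per_octave=3)
    gen = build_anchor_generator(anchor_cfg)
    print(gen.num_levels, gen.num_base_anchors)  # 3 [9, 9, 9]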
mmdet/core/anchor/point_generator.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+
3
+ from .builder import ANCHOR_GENERATORS
4
+
5
+
6
+ @ANCHOR_GENERATORS.register_module()
7
+ class PointGenerator(object):
8
+
9
+ def _meshgrid(self, x, y, row_major=True):
10
+ xx = x.repeat(len(y))
11
+ yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
12
+ if row_major:
13
+ return xx, yy
14
+ else:
15
+ return yy, xx
16
+
17
+ def grid_points(self, featmap_size, stride=16, device='cuda'):
18
+ feat_h, feat_w = featmap_size
19
+ shift_x = torch.arange(0., feat_w, device=device) * stride
20
+ shift_y = torch.arange(0., feat_h, device=device) * stride
21
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
22
+ stride = shift_x.new_full((shift_xx.shape[0], ), stride)
23
+ shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
24
+ all_points = shifts.to(device)
25
+ return all_points
26
+
27
+ def valid_flags(self, featmap_size, valid_size, device='cuda'):
28
+ feat_h, feat_w = featmap_size
29
+ valid_h, valid_w = valid_size
30
+ assert valid_h <= feat_h and valid_w <= feat_w
31
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
32
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
33
+ valid_x[:valid_w] = 1
34
+ valid_y[:valid_h] = 1
35
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
36
+ valid = valid_xx & valid_yy
37
+ return valid
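
A small CPU sketch of PointGenerator: each row of grid_points is (x, y, stride) for one grid cell, and valid_flags masks the cells that fall outside the valid (unpadded) region:

    from mmdet.core.anchor import PointGenerator

    pg = PointGenerator()
    points = pg.grid_points((2, 3), stride=8, device='cpu')
    print(points.shape)  # torch.Size([6, 3])
    flags = pg.valid_flags((2, 3), (2, 2), device='cpu')
    print(flags)  # tensor([ True,  True, False,  True,  True, False])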
mmdet/core/anchor/utils.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+
3
+
4
+ def images_to_levels(target, num_levels):
5
+ """Convert targets by image to targets by feature level.
6
+
7
+ [target_img0, target_img1] -> [target_level0, target_level1, ...]
8
+ """
9
+ target = torch.stack(target, 0)
10
+ level_targets = []
11
+ start = 0
12
+ for n in num_levels:
13
+ end = start + n
14
+ # level_targets.append(target[:, start:end].squeeze(0))
15
+ level_targets.append(target[:, start:end])
16
+ start = end
17
+ return level_targets
18
+
19
+
20
+ def anchor_inside_flags(flat_anchors,
21
+ valid_flags,
22
+ img_shape,
23
+ allowed_border=0):
24
+ """Check whether the anchors are inside the border.
25
+
26
+ Args:
27
+ flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
28
+ valid_flags (torch.Tensor): An existing valid flags of anchors.
29
+ img_shape (tuple(int)): Shape of current image.
30
+ allowed_border (int, optional): The border to allow the valid anchor.
31
+ Defaults to 0.
32
+
33
+ Returns:
34
+ torch.Tensor: Flags indicating whether the anchors are inside a \
35
+ valid range.
36
+ """
37
+ img_h, img_w = img_shape[:2]
38
+ if allowed_border >= 0:
39
+ inside_flags = valid_flags & \
40
+ (flat_anchors[:, 0] >= -allowed_border) & \
41
+ (flat_anchors[:, 1] >= -allowed_border) & \
42
+ (flat_anchors[:, 2] < img_w + allowed_border) & \
43
+ (flat_anchors[:, 3] < img_h + allowed_border)
44
+ else:
45
+ inside_flags = valid_flags
46
+ return inside_flags
47
+
48
+
49
+ def calc_region(bbox, ratio, featmap_size=None):
50
+ """Calculate a proportional bbox region.
51
+
52
+ The bbox center is fixed; the new h' and w' are h * ratio and w * ratio.
53
+
54
+ Args:
55
+ bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
56
+ ratio (float): Ratio of the output region.
57
+ featmap_size (tuple): Feature map size used for clipping the boundary.
58
+
59
+ Returns:
60
+ tuple: x1, y1, x2, y2
61
+ """
62
+ x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
63
+ y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
64
+ x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
65
+ y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
66
+ if featmap_size is not None:
67
+ x1 = x1.clamp(min=0, max=featmap_size[1])
68
+ y1 = y1.clamp(min=0, max=featmap_size[0])
69
+ x2 = x2.clamp(min=0, max=featmap_size[1])
70
+ y2 = y2.clamp(min=0, max=featmap_size[0])
71
+ return (x1, y1, x2, y2)
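As an illustration (not part of the commit), `images_to_levels` above simply stacks the per-image targets and slices them back out per feature level; a minimal sketch with made-up level sizes:

# Illustrative sketch only, not part of the diff.
import torch
from mmdet.core.anchor.utils import images_to_levels

num_levels = [4, 2]                    # e.g. 4 anchors on level 0, 2 on level 1
target_img0 = torch.arange(6.)         # flattened per-anchor targets of image 0
target_img1 = torch.arange(6.) + 100   # flattened per-anchor targets of image 1
level_targets = images_to_levels([target_img0, target_img1], num_levels)
print(level_targets[0].shape)  # torch.Size([2, 4]) -> (num_imgs, anchors on level 0)
print(level_targets[1].shape)  # torch.Size([2, 2])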
mmdet/core/bbox/__init__.py ADDED
@@ -0,0 +1,27 @@
+ from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
+                         MaxIoUAssigner, RegionAssigner)
+ from .builder import build_assigner, build_bbox_coder, build_sampler
+ from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
+                     TBLRBBoxCoder)
+ from .iou_calculators import BboxOverlaps2D, bbox_overlaps
+ from .samplers import (BaseSampler, CombinedSampler,
+                        InstanceBalancedPosSampler, IoUBalancedNegSampler,
+                        OHEMSampler, PseudoSampler, RandomSampler,
+                        SamplingResult, ScoreHLRSampler)
+ from .transforms import (bbox2distance, bbox2result, bbox2roi,
+                          bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
+                          bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
+                          distance2bbox, roi2bbox)
+
+ __all__ = [
+     'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
+     'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
+     'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+     'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
+     'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
+     'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
+     'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
+     'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
+     'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh',
+     'RegionAssigner'
+ ]
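For context (not part of the diff), these builders are normally driven by config dicts like the train_cfg entries in the configs added by this commit; a minimal sketch with assumed thresholds:

# Illustrative sketch only, not part of the commit.
from mmdet.core.bbox import build_assigner, build_sampler

assigner = build_assigner(
    dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.7,
        neg_iou_thr=0.3,
        min_pos_iou=0.3,
        match_low_quality=True,
        ignore_iof_thr=-1))
sampler = build_sampler(
    dict(type='RandomSampler', num=256, pos_fraction=0.5))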
mmdet/core/bbox/assigners/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .approx_max_iou_assigner import ApproxMaxIoUAssigner
+ from .assign_result import AssignResult
+ from .atss_assigner import ATSSAssigner
+ from .base_assigner import BaseAssigner
+ from .center_region_assigner import CenterRegionAssigner
+ from .grid_assigner import GridAssigner
+ from .hungarian_assigner import HungarianAssigner
+ from .max_iou_assigner import MaxIoUAssigner
+ from .point_assigner import PointAssigner
+ from .region_assigner import RegionAssigner
+
+ __all__ = [
+     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
+     'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
+     'HungarianAssigner', 'RegionAssigner'
+ ]
mmdet/core/bbox/assigners/approx_max_iou_assigner.py ADDED
@@ -0,0 +1,145 @@
1
+ import torch
2
+
3
+ from ..builder import BBOX_ASSIGNERS
4
+ from ..iou_calculators import build_iou_calculator
5
+ from .max_iou_assigner import MaxIoUAssigner
6
+
7
+
8
+ @BBOX_ASSIGNERS.register_module()
9
+ class ApproxMaxIoUAssigner(MaxIoUAssigner):
10
+ """Assign a corresponding gt bbox or background to each bbox.
11
+
12
+ Each proposal will be assigned an integer indicating the ground truth
13
+ index. (semi-positive index: gt label (0-based), -1: background)
14
+
15
+ - -1: negative sample, no assigned gt
16
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
17
+
18
+ Args:
19
+ pos_iou_thr (float): IoU threshold for positive bboxes.
20
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
21
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
22
+ positive bbox. Positive samples can have smaller IoU than
23
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
24
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
25
+ highest overlap with some gt to that gt.
26
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
27
+ `gt_bboxes_ignore` is specified). Negative values mean not
28
+ ignoring any bboxes.
29
+ ignore_wrt_candidates (bool): Whether to compute the iof between
30
+ `bboxes` and `gt_bboxes_ignore`, or the contrary.
31
+ match_low_quality (bool): Whether to allow low quality matches. This is
32
+ usually allowed for RPN and single stage detectors, but not allowed
33
+ in the second stage.
34
+ gpu_assign_thr (int): The upper bound of the number of GT for GPU
35
+ assign. When the number of gt is above this threshold, will assign
36
+ on CPU device. Negative values mean not assign on CPU.
37
+ """
38
+
39
+ def __init__(self,
40
+ pos_iou_thr,
41
+ neg_iou_thr,
42
+ min_pos_iou=.0,
43
+ gt_max_assign_all=True,
44
+ ignore_iof_thr=-1,
45
+ ignore_wrt_candidates=True,
46
+ match_low_quality=True,
47
+ gpu_assign_thr=-1,
48
+ iou_calculator=dict(type='BboxOverlaps2D')):
49
+ self.pos_iou_thr = pos_iou_thr
50
+ self.neg_iou_thr = neg_iou_thr
51
+ self.min_pos_iou = min_pos_iou
52
+ self.gt_max_assign_all = gt_max_assign_all
53
+ self.ignore_iof_thr = ignore_iof_thr
54
+ self.ignore_wrt_candidates = ignore_wrt_candidates
55
+ self.gpu_assign_thr = gpu_assign_thr
56
+ self.match_low_quality = match_low_quality
57
+ self.iou_calculator = build_iou_calculator(iou_calculator)
58
+
59
+ def assign(self,
60
+ approxs,
61
+ squares,
62
+ approxs_per_octave,
63
+ gt_bboxes,
64
+ gt_bboxes_ignore=None,
65
+ gt_labels=None):
66
+ """Assign gt to approxs.
67
+
68
+ This method assigns a gt bbox to each group of approxs (bboxes),
69
+ each group of approxs is represented by a base approx (bbox) and
70
+ will be assigned with -1, or a semi-positive number.
71
+ background_label (-1) means negative sample,
72
+ semi-positive number is the index (0-based) of assigned gt.
73
+ The assignment is done in following steps, the order matters.
74
+
75
+ 1. assign every bbox to background_label (-1)
76
+ 2. use the max IoU of each group of approxs to assign
77
+ 3. assign proposals whose iou with all gts < neg_iou_thr to background
78
+ 4. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
79
+ assign it to that bbox
80
+ 5. for each gt bbox, assign its nearest proposals (may be more than
81
+ one) to itself
82
+
83
+ Args:
84
+ approxs (Tensor): Bounding boxes to be assigned,
85
+ shape(approxs_per_octave*n, 4).
86
+ squares (Tensor): Base Bounding boxes to be assigned,
87
+ shape(n, 4).
88
+ approxs_per_octave (int): number of approxs per octave
89
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
90
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
91
+ labelled as `ignored`, e.g., crowd boxes in COCO.
92
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
93
+
94
+ Returns:
95
+ :obj:`AssignResult`: The assign result.
96
+ """
97
+ num_squares = squares.size(0)
98
+ num_gts = gt_bboxes.size(0)
99
+
100
+ if num_squares == 0 or num_gts == 0:
101
+ # No predictions and/or truth, return empty assignment
102
+ overlaps = approxs.new(num_gts, num_squares)
103
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
104
+ return assign_result
105
+
106
+ # re-organize anchors by approxs_per_octave x num_squares
107
+ approxs = torch.transpose(
108
+ approxs.view(num_squares, approxs_per_octave, 4), 0,
109
+ 1).contiguous().view(-1, 4)
110
+ assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
111
+ num_gts > self.gpu_assign_thr) else False
112
+ # compute overlap and assign gt on CPU when number of GT is large
113
+ if assign_on_cpu:
114
+ device = approxs.device
115
+ approxs = approxs.cpu()
116
+ gt_bboxes = gt_bboxes.cpu()
117
+ if gt_bboxes_ignore is not None:
118
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
119
+ if gt_labels is not None:
120
+ gt_labels = gt_labels.cpu()
121
+ all_overlaps = self.iou_calculator(approxs, gt_bboxes)
122
+
123
+ overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
124
+ num_gts).max(dim=0)
125
+ overlaps = torch.transpose(overlaps, 0, 1)
126
+
127
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
128
+ and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0):
129
+ if self.ignore_wrt_candidates:
130
+ ignore_overlaps = self.iou_calculator(
131
+ squares, gt_bboxes_ignore, mode='iof')
132
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
133
+ else:
134
+ ignore_overlaps = self.iou_calculator(
135
+ gt_bboxes_ignore, squares, mode='iof')
136
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
137
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
138
+
139
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
140
+ if assign_on_cpu:
141
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
142
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
143
+ if assign_result.labels is not None:
144
+ assign_result.labels = assign_result.labels.to(device)
145
+ return assign_result
mmdet/core/bbox/assigners/assign_result.py ADDED
@@ -0,0 +1,204 @@
1
+ import torch
2
+
3
+ from mmdet.utils import util_mixins
4
+
5
+
6
+ class AssignResult(util_mixins.NiceRepr):
7
+ """Stores assignments between predicted and truth boxes.
8
+
9
+ Attributes:
10
+ num_gts (int): the number of truth boxes considered when computing this
11
+ assignment
12
+
13
+ gt_inds (LongTensor): for each predicted box indicates the 1-based
14
+ index of the assigned truth box. 0 means unassigned and -1 means
15
+ ignore.
16
+
17
+ max_overlaps (FloatTensor): the iou between the predicted box and its
18
+ assigned truth box.
19
+
20
+ labels (None | LongTensor): If specified, for each predicted box
21
+ indicates the category label of the assigned truth box.
22
+
23
+ Example:
24
+ >>> # An assign result between 4 predicted boxes and 9 true boxes
25
+ >>> # where only two boxes were assigned.
26
+ >>> num_gts = 9
27
+ >>> max_overlaps = torch.LongTensor([0, .5, .9, 0])
28
+ >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
29
+ >>> labels = torch.LongTensor([0, 3, 4, 0])
30
+ >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
31
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
32
+ <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
33
+ labels.shape=(4,))>
34
+ >>> # Force addition of gt labels (when adding gt as proposals)
35
+ >>> new_labels = torch.LongTensor([3, 4, 5])
36
+ >>> self.add_gt_(new_labels)
37
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
38
+ <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
39
+ labels.shape=(7,))>
40
+ """
41
+
42
+ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
43
+ self.num_gts = num_gts
44
+ self.gt_inds = gt_inds
45
+ self.max_overlaps = max_overlaps
46
+ self.labels = labels
47
+ # Interface for possible user-defined properties
48
+ self._extra_properties = {}
49
+
50
+ @property
51
+ def num_preds(self):
52
+ """int: the number of predictions in this assignment"""
53
+ return len(self.gt_inds)
54
+
55
+ def set_extra_property(self, key, value):
56
+ """Set user-defined new property."""
57
+ assert key not in self.info
58
+ self._extra_properties[key] = value
59
+
60
+ def get_extra_property(self, key):
61
+ """Get user-defined property."""
62
+ return self._extra_properties.get(key, None)
63
+
64
+ @property
65
+ def info(self):
66
+ """dict: a dictionary of info about the object"""
67
+ basic_info = {
68
+ 'num_gts': self.num_gts,
69
+ 'num_preds': self.num_preds,
70
+ 'gt_inds': self.gt_inds,
71
+ 'max_overlaps': self.max_overlaps,
72
+ 'labels': self.labels,
73
+ }
74
+ basic_info.update(self._extra_properties)
75
+ return basic_info
76
+
77
+ def __nice__(self):
78
+ """str: a "nice" summary string describing this assign result"""
79
+ parts = []
80
+ parts.append(f'num_gts={self.num_gts!r}')
81
+ if self.gt_inds is None:
82
+ parts.append(f'gt_inds={self.gt_inds!r}')
83
+ else:
84
+ parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
85
+ if self.max_overlaps is None:
86
+ parts.append(f'max_overlaps={self.max_overlaps!r}')
87
+ else:
88
+ parts.append('max_overlaps.shape='
89
+ f'{tuple(self.max_overlaps.shape)!r}')
90
+ if self.labels is None:
91
+ parts.append(f'labels={self.labels!r}')
92
+ else:
93
+ parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
94
+ return ', '.join(parts)
95
+
96
+ @classmethod
97
+ def random(cls, **kwargs):
98
+ """Create random AssignResult for tests or debugging.
99
+
100
+ Args:
101
+ num_preds: number of predicted boxes
102
+ num_gts: number of true boxes
103
+ p_ignore (float): probability of a predicted box assigned to an
104
+ ignored truth
105
+ p_assigned (float): probability of a predicted box not being
106
+ assigned
107
+ p_use_label (float | bool): with labels or not
108
+ rng (None | int | numpy.random.RandomState): seed or state
109
+
110
+ Returns:
111
+ :obj:`AssignResult`: Randomly generated assign results.
112
+
113
+ Example:
114
+ >>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
115
+ >>> self = AssignResult.random()
116
+ >>> print(self.info)
117
+ """
118
+ from mmdet.core.bbox import demodata
119
+ rng = demodata.ensure_rng(kwargs.get('rng', None))
120
+
121
+ num_gts = kwargs.get('num_gts', None)
122
+ num_preds = kwargs.get('num_preds', None)
123
+ p_ignore = kwargs.get('p_ignore', 0.3)
124
+ p_assigned = kwargs.get('p_assigned', 0.7)
125
+ p_use_label = kwargs.get('p_use_label', 0.5)
126
+ num_classes = kwargs.get('num_classes', 3)
127
+
128
+ if num_gts is None:
129
+ num_gts = rng.randint(0, 8)
130
+ if num_preds is None:
131
+ num_preds = rng.randint(0, 16)
132
+
133
+ if num_gts == 0:
134
+ max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
135
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
136
+ if p_use_label is True or p_use_label < rng.rand():
137
+ labels = torch.zeros(num_preds, dtype=torch.int64)
138
+ else:
139
+ labels = None
140
+ else:
141
+ import numpy as np
142
+ # Create an overlap for each predicted box
143
+ max_overlaps = torch.from_numpy(rng.rand(num_preds))
144
+
145
+ # Construct gt_inds for each predicted box
146
+ is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
147
+ # maximum number of assignments constraints
148
+ n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
149
+
150
+ assigned_idxs = np.where(is_assigned)[0]
151
+ rng.shuffle(assigned_idxs)
152
+ assigned_idxs = assigned_idxs[0:n_assigned]
153
+ assigned_idxs.sort()
154
+
155
+ is_assigned[:] = 0
156
+ is_assigned[assigned_idxs] = True
157
+
158
+ is_ignore = torch.from_numpy(
159
+ rng.rand(num_preds) < p_ignore) & is_assigned
160
+
161
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
162
+
163
+ true_idxs = np.arange(num_gts)
164
+ rng.shuffle(true_idxs)
165
+ true_idxs = torch.from_numpy(true_idxs)
166
+ gt_inds[is_assigned] = true_idxs[:n_assigned]
167
+
168
+ gt_inds = torch.from_numpy(
169
+ rng.randint(1, num_gts + 1, size=num_preds))
170
+ gt_inds[is_ignore] = -1
171
+ gt_inds[~is_assigned] = 0
172
+ max_overlaps[~is_assigned] = 0
173
+
174
+ if p_use_label is True or p_use_label < rng.rand():
175
+ if num_classes == 0:
176
+ labels = torch.zeros(num_preds, dtype=torch.int64)
177
+ else:
178
+ labels = torch.from_numpy(
179
+ # remind that we set FG labels to [0, num_class-1]
180
+ # since mmdet v2.0
181
+ # BG cat_id: num_class
182
+ rng.randint(0, num_classes, size=num_preds))
183
+ labels[~is_assigned] = 0
184
+ else:
185
+ labels = None
186
+
187
+ self = cls(num_gts, gt_inds, max_overlaps, labels)
188
+ return self
189
+
190
+ def add_gt_(self, gt_labels):
191
+ """Add ground truth as assigned results.
192
+
193
+ Args:
194
+ gt_labels (torch.Tensor): Labels of gt boxes
195
+ """
196
+ self_inds = torch.arange(
197
+ 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
198
+ self.gt_inds = torch.cat([self_inds, self.gt_inds])
199
+
200
+ self.max_overlaps = torch.cat(
201
+ [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
202
+
203
+ if self.labels is not None:
204
+ self.labels = torch.cat([gt_labels, self.labels])
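A short usage sketch (not part of the commit), following the class docstring's example: 4 predictions against 9 gts, two of them assigned, then gts appended as their own matches via add_gt_:

# Illustrative sketch only, not part of the diff.
import torch
from mmdet.core.bbox.assigners import AssignResult

gt_inds = torch.LongTensor([-1, 1, 2, 0])      # -1 ignore, 0 background, >0 gt index (1-based)
max_overlaps = torch.FloatTensor([0, .5, .9, 0])
labels = torch.LongTensor([0, 3, 4, 0])
result = AssignResult(9, gt_inds, max_overlaps, labels)
print(result.num_preds)                        # 4
result.add_gt_(torch.LongTensor([3, 4, 5]))    # prepend the gts as assigned predictions
print(result.num_preds)                        # 7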
mmdet/core/bbox/assigners/atss_assigner.py ADDED
@@ -0,0 +1,178 @@
1
+ import torch
2
+
3
+ from ..builder import BBOX_ASSIGNERS
4
+ from ..iou_calculators import build_iou_calculator
5
+ from .assign_result import AssignResult
6
+ from .base_assigner import BaseAssigner
7
+
8
+
9
+ @BBOX_ASSIGNERS.register_module()
10
+ class ATSSAssigner(BaseAssigner):
11
+ """Assign a corresponding gt bbox or background to each bbox.
12
+
13
+ Each proposals will be assigned with `0` or a positive integer
14
+ indicating the ground truth index.
15
+
16
+ - 0: negative sample, no assigned gt
17
+ - positive integer: positive sample, index (1-based) of assigned gt
18
+
19
+ Args:
20
+ topk (int): number of bboxes selected on each level
21
+ """
22
+
23
+ def __init__(self,
24
+ topk,
25
+ iou_calculator=dict(type='BboxOverlaps2D'),
26
+ ignore_iof_thr=-1):
27
+ self.topk = topk
28
+ self.iou_calculator = build_iou_calculator(iou_calculator)
29
+ self.ignore_iof_thr = ignore_iof_thr
30
+
31
+ # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
32
+
33
+ def assign(self,
34
+ bboxes,
35
+ num_level_bboxes,
36
+ gt_bboxes,
37
+ gt_bboxes_ignore=None,
38
+ gt_labels=None):
39
+ """Assign gt to bboxes.
40
+
41
+ The assignment is done in following steps
42
+
43
+ 1. compute iou between all bbox (bbox of all pyramid levels) and gt
44
+ 2. compute center distance between all bbox and gt
45
+ 3. on each pyramid level, for each gt, select k bbox whose center
46
+ are closest to the gt center, so we select k*l bboxes in total as
47
+ candidates for each gt
48
+ 4. get corresponding iou for these candidates, and compute the
49
+ mean and std, set mean + std as the iou threshold
50
+ 5. select these candidates whose iou are greater than or equal to
51
+ the threshold as positive
52
+ 6. limit the positive sample's center in gt
53
+
54
+
55
+ Args:
56
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
57
+ num_level_bboxes (List): num of bboxes in each level
58
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
59
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
60
+ labelled as `ignored`, e.g., crowd boxes in COCO.
61
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
62
+
63
+ Returns:
64
+ :obj:`AssignResult`: The assign result.
65
+ """
66
+ INF = 100000000
67
+ bboxes = bboxes[:, :4]
68
+ num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
69
+
70
+ # compute iou between all bbox and gt
71
+ overlaps = self.iou_calculator(bboxes, gt_bboxes)
72
+
73
+ # assign 0 by default
74
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
75
+ 0,
76
+ dtype=torch.long)
77
+
78
+ if num_gt == 0 or num_bboxes == 0:
79
+ # No ground truth or boxes, return empty assignment
80
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
81
+ if num_gt == 0:
82
+ # No truth, assign everything to background
83
+ assigned_gt_inds[:] = 0
84
+ if gt_labels is None:
85
+ assigned_labels = None
86
+ else:
87
+ assigned_labels = overlaps.new_full((num_bboxes, ),
88
+ -1,
89
+ dtype=torch.long)
90
+ return AssignResult(
91
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
92
+
93
+ # compute center distance between all bbox and gt
94
+ gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
95
+ gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
96
+ gt_points = torch.stack((gt_cx, gt_cy), dim=1)
97
+
98
+ bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
99
+ bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
100
+ bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)
101
+
102
+ distances = (bboxes_points[:, None, :] -
103
+ gt_points[None, :, :]).pow(2).sum(-1).sqrt()
104
+
105
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
106
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
107
+ ignore_overlaps = self.iou_calculator(
108
+ bboxes, gt_bboxes_ignore, mode='iof')
109
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
110
+ ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
111
+ distances[ignore_idxs, :] = INF
112
+ assigned_gt_inds[ignore_idxs] = -1
113
+
114
+ # Selecting candidates based on the center distance
115
+ candidate_idxs = []
116
+ start_idx = 0
117
+ for level, bboxes_per_level in enumerate(num_level_bboxes):
118
+ # on each pyramid level, for each gt,
119
+ # select k bbox whose center are closest to the gt center
120
+ end_idx = start_idx + bboxes_per_level
121
+ distances_per_level = distances[start_idx:end_idx, :]
122
+ selectable_k = min(self.topk, bboxes_per_level)
123
+ _, topk_idxs_per_level = distances_per_level.topk(
124
+ selectable_k, dim=0, largest=False)
125
+ candidate_idxs.append(topk_idxs_per_level + start_idx)
126
+ start_idx = end_idx
127
+ candidate_idxs = torch.cat(candidate_idxs, dim=0)
128
+
129
+ # get corresponding iou for these candidates, and compute the
130
+ # mean and std, set mean + std as the iou threshold
131
+ candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
132
+ overlaps_mean_per_gt = candidate_overlaps.mean(0)
133
+ overlaps_std_per_gt = candidate_overlaps.std(0)
134
+ overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
135
+
136
+ is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
137
+
138
+ # limit the positive sample's center in gt
139
+ for gt_idx in range(num_gt):
140
+ candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
141
+ ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
142
+ num_gt, num_bboxes).contiguous().view(-1)
143
+ ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
144
+ num_gt, num_bboxes).contiguous().view(-1)
145
+ candidate_idxs = candidate_idxs.view(-1)
146
+
147
+ # calculate the left, top, right, bottom distance between positive
148
+ # bbox center and gt side
149
+ l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
150
+ t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
151
+ r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
152
+ b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
153
+ is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
154
+ is_pos = is_pos & is_in_gts
155
+
156
+ # if an anchor box is assigned to multiple gts,
157
+ # the one with the highest IoU will be selected.
158
+ overlaps_inf = torch.full_like(overlaps,
159
+ -INF).t().contiguous().view(-1)
160
+ index = candidate_idxs.view(-1)[is_pos.view(-1)]
161
+ overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
162
+ overlaps_inf = overlaps_inf.view(num_gt, -1).t()
163
+
164
+ max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
165
+ assigned_gt_inds[
166
+ max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
167
+
168
+ if gt_labels is not None:
169
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
170
+ pos_inds = torch.nonzero(
171
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
172
+ if pos_inds.numel() > 0:
173
+ assigned_labels[pos_inds] = gt_labels[
174
+ assigned_gt_inds[pos_inds] - 1]
175
+ else:
176
+ assigned_labels = None
177
+ return AssignResult(
178
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
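A minimal numeric sketch (not part of the commit) of the adaptive threshold used in steps 4-5 above, with made-up candidate IoUs per gt:

# Illustrative sketch only, not part of the diff.
import torch

candidate_overlaps = torch.tensor([[0.12, 0.55],
                                   [0.30, 0.70],
                                   [0.05, 0.42]])   # (num_candidates, num_gt)
thr = candidate_overlaps.mean(0) + candidate_overlaps.std(0)
is_pos = candidate_overlaps >= thr[None, :]
print(thr)     # approximately tensor([0.2856, 0.6968])
print(is_pos)  # only the candidates above their gt's adaptive threshold survive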
mmdet/core/bbox/assigners/base_assigner.py ADDED
@@ -0,0 +1,9 @@
+ from abc import ABCMeta, abstractmethod
+
+
+ class BaseAssigner(metaclass=ABCMeta):
+     """Base assigner that assigns boxes to ground truth boxes."""
+
+     @abstractmethod
+     def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+         """Assign boxes to either a ground truth box or a negative sample."""
mmdet/core/bbox/assigners/center_region_assigner.py ADDED
@@ -0,0 +1,335 @@
1
+ import torch
2
+
3
+ from ..builder import BBOX_ASSIGNERS
4
+ from ..iou_calculators import build_iou_calculator
5
+ from .assign_result import AssignResult
6
+ from .base_assigner import BaseAssigner
7
+
8
+
9
+ def scale_boxes(bboxes, scale):
10
+ """Expand an array of boxes by a given scale.
11
+
12
+ Args:
13
+ bboxes (Tensor): Shape (m, 4)
14
+ scale (float): The scale factor of bboxes
15
+
16
+ Returns:
17
+ (Tensor): Shape (m, 4). Scaled bboxes
18
+ """
19
+ assert bboxes.size(1) == 4
20
+ w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
21
+ h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
22
+ x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
23
+ y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
24
+
25
+ w_half *= scale
26
+ h_half *= scale
27
+
28
+ boxes_scaled = torch.zeros_like(bboxes)
29
+ boxes_scaled[:, 0] = x_c - w_half
30
+ boxes_scaled[:, 2] = x_c + w_half
31
+ boxes_scaled[:, 1] = y_c - h_half
32
+ boxes_scaled[:, 3] = y_c + h_half
33
+ return boxes_scaled
34
+
35
+
36
+ def is_located_in(points, bboxes):
37
+ """Are points located in bboxes.
38
+
39
+ Args:
40
+ points (Tensor): Points, shape: (m, 2).
41
+ bboxes (Tensor): Bounding boxes, shape: (n, 4).
42
+
43
+ Return:
44
+ Tensor: Flags indicating if points are located in bboxes, shape: (m, n).
45
+ """
46
+ assert points.size(1) == 2
47
+ assert bboxes.size(1) == 4
48
+ return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
49
+ (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
50
+ (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
51
+ (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
52
+
53
+
54
+ def bboxes_area(bboxes):
55
+ """Compute the area of an array of bboxes.
56
+
57
+ Args:
58
+ bboxes (Tensor): The coordinates of bboxes. Shape: (m, 4)
59
+
60
+ Returns:
61
+ Tensor: Area of the bboxes. Shape: (m, )
62
+ """
63
+ assert bboxes.size(1) == 4
64
+ w = (bboxes[:, 2] - bboxes[:, 0])
65
+ h = (bboxes[:, 3] - bboxes[:, 1])
66
+ areas = w * h
67
+ return areas
68
+
69
+
70
+ @BBOX_ASSIGNERS.register_module()
71
+ class CenterRegionAssigner(BaseAssigner):
72
+ """Assign pixels at the center region of a bbox as positive.
73
+
74
+ Each proposal will be assigned `-1`, `0`, or a positive integer
75
+ indicating the ground truth index.
76
+ - -1: negative samples
77
+ - semi-positive numbers: positive sample, index (0-based) of assigned gt
78
+
79
+ Args:
80
+ pos_scale (float): Threshold within which pixels are
81
+ labelled as positive.
82
+ neg_scale (float): Threshold above which pixels are
83
+ labelled as negative.
84
+ min_pos_iof (float): Minimum iof of a pixel with a gt to be
85
+ labelled as positive. Default: 1e-2
86
+ ignore_gt_scale (float): Threshold within which the pixels
87
+ are ignored when the gt is labelled as shadowed. Default: 0.5
88
+ foreground_dominate (bool): If True, the bbox will be assigned as
89
+ positive when a gt's kernel region overlaps with another's shadowed
90
+ (ignored) region, otherwise it is set as ignored. Default to False.
91
+ """
92
+
93
+ def __init__(self,
94
+ pos_scale,
95
+ neg_scale,
96
+ min_pos_iof=1e-2,
97
+ ignore_gt_scale=0.5,
98
+ foreground_dominate=False,
99
+ iou_calculator=dict(type='BboxOverlaps2D')):
100
+ self.pos_scale = pos_scale
101
+ self.neg_scale = neg_scale
102
+ self.min_pos_iof = min_pos_iof
103
+ self.ignore_gt_scale = ignore_gt_scale
104
+ self.foreground_dominate = foreground_dominate
105
+ self.iou_calculator = build_iou_calculator(iou_calculator)
106
+
107
+ def get_gt_priorities(self, gt_bboxes):
108
+ """Get gt priorities according to their areas.
109
+
110
+ Smaller gt has higher priority.
111
+
112
+ Args:
113
+ gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
114
+
115
+ Returns:
116
+ Tensor: The priority of gts so that gts with larger priority is \
117
+ more likely to be assigned. Shape (k, )
118
+ """
119
+ gt_areas = bboxes_area(gt_bboxes)
120
+ # Rank all gt bbox areas. Smaller objects has larger priority
121
+ _, sort_idx = gt_areas.sort(descending=True)
122
+ sort_idx = sort_idx.argsort()
123
+ return sort_idx
124
+
125
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
126
+ """Assign gt to bboxes.
127
+
128
+ This method assigns gts to every bbox (proposal/anchor), each bbox \
129
+ will be assigned with -1, or a semi-positive number. -1 means \
130
+ negative sample, semi-positive number is the index (0-based) of \
131
+ assigned gt.
132
+
133
+ Args:
134
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
135
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
136
+ gt_bboxes_ignore (tensor, optional): Ground truth bboxes that are
137
+ labelled as `ignored`, e.g., crowd boxes in COCO.
138
+ gt_labels (tensor, optional): Label of gt_bboxes, shape (num_gts,).
139
+
140
+ Returns:
141
+ :obj:`AssignResult`: The assigned result. Note that \
142
+ shadowed_labels of shape (N, 2) is also added as an \
143
+ `assign_result` attribute. `shadowed_labels` is a tensor \
144
+ composed of N pairs of [anchor_ind, class_label], where N
145
+ is the number of anchors that lie in the outer region of a \
146
+ gt, anchor_ind is the shadowed anchor index and class_label \
147
+ is the shadowed class label.
148
+
149
+ Example:
150
+ >>> self = CenterRegionAssigner(0.2, 0.2)
151
+ >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
152
+ >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
153
+ >>> assign_result = self.assign(bboxes, gt_bboxes)
154
+ >>> expected_gt_inds = torch.LongTensor([1, 0])
155
+ >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
156
+ """
157
+ # There are in total 5 steps in the pixel assignment
158
+ # 1. Find core (the center region, say inner 0.2)
159
+ # and shadow (the relatively ourter part, say inner 0.2-0.5)
160
+ # regions of every gt.
161
+ # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
162
+ # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
163
+ # the image.
164
+ # 3.1. For overlapping objects, the prior bboxes in gt_core is
165
+ # assigned with the object with smallest area
166
+ # 4. Assign prior bboxes with class label according to its gt id.
167
+ # 4.1. Assign -1 to prior bboxes lying in shadowed gts
168
+ # 4.2. Assign positive prior boxes with the corresponding label
169
+ # 5. Find pixels lying in the shadow of an object and assign them with
170
+ # background label, but set the loss weight of its corresponding
171
+ # gt to zero.
172
+ assert bboxes.size(1) == 4, 'bboxes must have size of 4'
173
+ # 1. Find core positive and shadow region of every gt
174
+ gt_core = scale_boxes(gt_bboxes, self.pos_scale)
175
+ gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)
176
+
177
+ # 2. Find prior bboxes that lie in gt_core and gt_shadow regions
178
+ bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
179
+ # The center points lie within the gt boxes
180
+ is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
181
+ # Only calculate bbox and gt_core IoF. This enables small prior bboxes
182
+ # to match large gts
183
+ bbox_and_gt_core_overlaps = self.iou_calculator(
184
+ bboxes, gt_core, mode='iof')
185
+ # The center point of effective priors should be within the gt box
186
+ is_bbox_in_gt_core = is_bbox_in_gt & (
187
+ bbox_and_gt_core_overlaps > self.min_pos_iof) # shape (n, k)
188
+
189
+ is_bbox_in_gt_shadow = (
190
+ self.iou_calculator(bboxes, gt_shadow, mode='iof') >
191
+ self.min_pos_iof)
192
+ # Rule out center effective positive pixels
193
+ is_bbox_in_gt_shadow &= (~is_bbox_in_gt_core)
194
+
195
+ num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
196
+ if num_gts == 0 or num_bboxes == 0:
197
+ # If no gts exist, assign all pixels to negative
198
+ assigned_gt_ids = \
199
+ is_bbox_in_gt_core.new_zeros((num_bboxes,),
200
+ dtype=torch.long)
201
+ pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
202
+ else:
203
+ # Step 3: assign a one-hot gt id to each pixel, and smaller objects
204
+ # have high priority to assign the pixel.
205
+ sort_idx = self.get_gt_priorities(gt_bboxes)
206
+ assigned_gt_ids, pixels_in_gt_shadow = \
207
+ self.assign_one_hot_gt_indices(is_bbox_in_gt_core,
208
+ is_bbox_in_gt_shadow,
209
+ gt_priority=sort_idx)
210
+
211
+ if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
212
+ # Scale the ignored gt boxes and mark priors whose centers fall inside them as ignored (-1)
213
+ gt_bboxes_ignore = scale_boxes(
214
+ gt_bboxes_ignore, scale=self.ignore_gt_scale)
215
+ is_bbox_in_ignored_gts = is_located_in(bbox_centers,
216
+ gt_bboxes_ignore)
217
+ is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
218
+ assigned_gt_ids[is_bbox_in_ignored_gts] = -1
219
+
220
+ # 4. Assign prior bboxes with class label according to its gt id.
221
+ assigned_labels = None
222
+ shadowed_pixel_labels = None
223
+ if gt_labels is not None:
224
+ # Default assigned label is the background (-1)
225
+ assigned_labels = assigned_gt_ids.new_full((num_bboxes, ), -1)
226
+ pos_inds = torch.nonzero(
227
+ assigned_gt_ids > 0, as_tuple=False).squeeze()
228
+ if pos_inds.numel() > 0:
229
+ assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds]
230
+ - 1]
231
+ # 5. Find pixels lying in the shadow of an object
232
+ shadowed_pixel_labels = pixels_in_gt_shadow.clone()
233
+ if pixels_in_gt_shadow.numel() > 0:
234
+ pixel_idx, gt_idx =\
235
+ pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
236
+ assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
237
+ 'Some pixels are dually assigned to ignore and gt!'
238
+ shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
239
+ override = (
240
+ assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
241
+ if self.foreground_dominate:
242
+ # When a pixel is both positive and shadowed, set it as pos
243
+ shadowed_pixel_labels = shadowed_pixel_labels[~override]
244
+ else:
245
+ # When a pixel is both pos and shadowed, set it as shadowed
246
+ assigned_labels[pixel_idx[override]] = -1
247
+ assigned_gt_ids[pixel_idx[override]] = 0
248
+
249
+ assign_result = AssignResult(
250
+ num_gts, assigned_gt_ids, None, labels=assigned_labels)
251
+ # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
252
+ assign_result.set_extra_property('shadowed_labels',
253
+ shadowed_pixel_labels)
254
+ return assign_result
255
+
256
+ def assign_one_hot_gt_indices(self,
257
+ is_bbox_in_gt_core,
258
+ is_bbox_in_gt_shadow,
259
+ gt_priority=None):
260
+ """Assign only one gt index to each prior box.
261
+
262
+ Gts with large gt_priority are more likely to be assigned.
263
+
264
+ Args:
265
+ is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
266
+ is in the core area of a gt (e.g. 0-0.2).
267
+ Shape: (num_prior, num_gt).
268
+ is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
269
+ center is in the shadowed area of a gt (e.g. 0.2-0.5).
270
+ Shape: (num_prior, num_gt).
271
+ gt_priority (Tensor): Priorities of gts. The gt with a higher
272
+ priority is more likely to be assigned to the bbox when the bbox
273
+ match with multiple gts. Shape: (num_gt, ).
274
+
275
+ Returns:
276
+ tuple: Returns (assigned_gt_inds, shadowed_gt_inds).
277
+
278
+ - assigned_gt_inds: The assigned gt index of each prior bbox \
279
+ (i.e. index from 1 to num_gts). Shape: (num_prior, ).
280
+ - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
281
+ shape (num_ignore, 2) with first column being the \
282
+ shadowed prior bbox indices and the second column the \
283
+ shadowed gt indices (1-based).
284
+ """
285
+ num_bboxes, num_gts = is_bbox_in_gt_core.shape
286
+
287
+ if gt_priority is None:
288
+ gt_priority = torch.arange(
289
+ num_gts, device=is_bbox_in_gt_core.device)
290
+ assert gt_priority.size(0) == num_gts
291
+ # The bigger gt_priority, the more preferable to be assigned
292
+ # The assigned inds are by default 0 (background)
293
+ assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
294
+ dtype=torch.long)
295
+ # Shadowed bboxes are assigned to be background. But the corresponding
296
+ # label is ignored during loss calculation, which is done through
297
+ # shadowed_gt_inds
298
+ shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
299
+ if is_bbox_in_gt_core.sum() == 0: # No gt match
300
+ shadowed_gt_inds[:, 1] += 1 # 1-based. For consistency issue
301
+ return assigned_gt_inds, shadowed_gt_inds
302
+
303
+ # The priority of each prior box and gt pair. If one prior box is
304
+ # matched to multiple gts, only the pair with the highest priority
305
+ # is saved
306
+ pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
307
+ -1,
308
+ dtype=torch.long)
309
+
310
+ # Each bbox could match with multiple gts.
311
+ # The following codes deal with this situation
312
+ # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
313
+ inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
314
+ # The matched gt index of each positive bbox. Length >= num_pos_anchor
315
+ # , since one bbox could match multiple gts
316
+ matched_bbox_gt_inds = torch.nonzero(
317
+ is_bbox_in_gt_core, as_tuple=False)[:, 1]
318
+ # Assign priority to each bbox-gt pair.
319
+ pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
320
+ _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
321
+ assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based
322
+ # Zero-out the assigned anchor box to filter the shadowed gt indices
323
+ is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
324
+ # Concat the shadowed indices due to overlapping with that out side of
325
+ # effective scale. shape: (total_num_ignore, 2)
326
+ shadowed_gt_inds = torch.cat(
327
+ (shadowed_gt_inds, torch.nonzero(
328
+ is_bbox_in_gt_core, as_tuple=False)),
329
+ dim=0)
330
+ # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
331
+ is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
332
+ # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
333
+ if shadowed_gt_inds.numel() > 0:
334
+ shadowed_gt_inds[:, 1] += 1
335
+ return assigned_gt_inds, shadowed_gt_inds
mmdet/core/bbox/assigners/grid_assigner.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+
3
+ from ..builder import BBOX_ASSIGNERS
4
+ from ..iou_calculators import build_iou_calculator
5
+ from .assign_result import AssignResult
6
+ from .base_assigner import BaseAssigner
7
+
8
+
9
+ @BBOX_ASSIGNERS.register_module()
10
+ class GridAssigner(BaseAssigner):
11
+ """Assign a corresponding gt bbox or background to each bbox.
12
+
13
+ Each proposal will be assigned `-1`, `0`, or a positive integer
14
+ indicating the ground truth index.
15
+
16
+ - -1: don't care
17
+ - 0: negative sample, no assigned gt
18
+ - positive integer: positive sample, index (1-based) of assigned gt
19
+
20
+ Args:
21
+ pos_iou_thr (float): IoU threshold for positive bboxes.
22
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
23
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
24
+ positive bbox. Positive samples can have smaller IoU than
25
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
26
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
27
+ highest overlap with some gt to that gt.
28
+ """
29
+
30
+ def __init__(self,
31
+ pos_iou_thr,
32
+ neg_iou_thr,
33
+ min_pos_iou=.0,
34
+ gt_max_assign_all=True,
35
+ iou_calculator=dict(type='BboxOverlaps2D')):
36
+ self.pos_iou_thr = pos_iou_thr
37
+ self.neg_iou_thr = neg_iou_thr
38
+ self.min_pos_iou = min_pos_iou
39
+ self.gt_max_assign_all = gt_max_assign_all
40
+ self.iou_calculator = build_iou_calculator(iou_calculator)
41
+
42
+ def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None):
43
+ """Assign gt to bboxes. The process is very much like the max iou
44
+ assigner, except that positive samples are constrained within the cell
45
+ that the gt boxes fell in.
46
+
47
+ This method assigns a gt bbox to every bbox (proposal/anchor), each bbox
48
+ will be assigned with -1, 0, or a positive number. -1 means don't care,
49
+ 0 means negative sample, positive number is the index (1-based) of
50
+ assigned gt.
51
+ The assignment is done in following steps, the order matters.
52
+
53
+ 1. assign every bbox to -1
54
+ 2. assign proposals whose iou with all gts <= neg_iou_thr to 0
55
+ 3. for each bbox within a cell, if the iou with its nearest gt >
56
+ pos_iou_thr and the center of that gt falls inside the cell,
57
+ assign it to that bbox
58
+ 4. for each gt bbox, assign its nearest proposals within the cell the
59
+ gt bbox falls in to itself.
60
+
61
+ Args:
62
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
63
+ box_responsible_flags (Tensor): flag to indicate whether box is
64
+ responsible for prediction, shape(n, )
65
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
66
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
67
+
68
+ Returns:
69
+ :obj:`AssignResult`: The assign result.
70
+ """
71
+ num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
72
+
73
+ # compute iou between all gt and bboxes
74
+ overlaps = self.iou_calculator(gt_bboxes, bboxes)
75
+
76
+ # 1. assign -1 by default
77
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
78
+ -1,
79
+ dtype=torch.long)
80
+
81
+ if num_gts == 0 or num_bboxes == 0:
82
+ # No ground truth or boxes, return empty assignment
83
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
84
+ if num_gts == 0:
85
+ # No truth, assign everything to background
86
+ assigned_gt_inds[:] = 0
87
+ if gt_labels is None:
88
+ assigned_labels = None
89
+ else:
90
+ assigned_labels = overlaps.new_full((num_bboxes, ),
91
+ -1,
92
+ dtype=torch.long)
93
+ return AssignResult(
94
+ num_gts,
95
+ assigned_gt_inds,
96
+ max_overlaps,
97
+ labels=assigned_labels)
98
+
99
+ # 2. assign negative: below
100
+ # for each anchor, which gt best overlaps with it
101
+ # for each anchor, the max iou of all gts
102
+ # shape of max_overlaps == argmax_overlaps == num_bboxes
103
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
104
+
105
+ if isinstance(self.neg_iou_thr, float):
106
+ assigned_gt_inds[(max_overlaps >= 0)
107
+ & (max_overlaps <= self.neg_iou_thr)] = 0
108
+ elif isinstance(self.neg_iou_thr, (tuple, list)):
109
+ assert len(self.neg_iou_thr) == 2
110
+ assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
111
+ & (max_overlaps <= self.neg_iou_thr[1])] = 0
112
+
113
+ # 3. assign positive: falls into responsible cell and above
114
+ # positive IOU threshold, the order matters.
115
+ # the prior condition of comparison is to filter out all
116
+ # unrelated anchors, i.e. not box_responsible_flags
117
+ overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1.
118
+
119
+ # calculate max_overlaps again, but this time we only consider IOUs
120
+ # for anchors responsible for prediction
121
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
122
+
123
+ # for each gt, which anchor best overlaps with it
124
+ # for each gt, the max iou of all proposals
125
+ # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
126
+ gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
127
+
128
+ pos_inds = (max_overlaps >
129
+ self.pos_iou_thr) & box_responsible_flags.type(torch.bool)
130
+ assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
131
+
132
+ # 4. assign positive to max overlapped anchors within responsible cell
133
+ for i in range(num_gts):
134
+ if gt_max_overlaps[i] > self.min_pos_iou:
135
+ if self.gt_max_assign_all:
136
+ max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
137
+ box_responsible_flags.type(torch.bool)
138
+ assigned_gt_inds[max_iou_inds] = i + 1
139
+ elif box_responsible_flags[gt_argmax_overlaps[i]]:
140
+ assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
141
+
142
+ # assign labels of positive anchors
143
+ if gt_labels is not None:
144
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
145
+ pos_inds = torch.nonzero(
146
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
147
+ if pos_inds.numel() > 0:
148
+ assigned_labels[pos_inds] = gt_labels[
149
+ assigned_gt_inds[pos_inds] - 1]
150
+
151
+ else:
152
+ assigned_labels = None
153
+
154
+ return AssignResult(
155
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
mmdet/core/bbox/assigners/hungarian_assigner.py ADDED
@@ -0,0 +1,145 @@
1
+ import torch
2
+
3
+ from ..builder import BBOX_ASSIGNERS
4
+ from ..match_costs import build_match_cost
5
+ from ..transforms import bbox_cxcywh_to_xyxy
6
+ from .assign_result import AssignResult
7
+ from .base_assigner import BaseAssigner
8
+
9
+ try:
10
+ from scipy.optimize import linear_sum_assignment
11
+ except ImportError:
12
+ linear_sum_assignment = None
13
+
14
+
15
+ @BBOX_ASSIGNERS.register_module()
16
+ class HungarianAssigner(BaseAssigner):
17
+ """Computes one-to-one matching between predictions and ground truth.
18
+
19
+ This class computes an assignment between the targets and the predictions
20
+ based on the costs. The costs are weighted sum of three components:
21
+ classification cost, regression L1 cost and regression iou cost. The
22
+ targets don't include the no_object, so generally there are more
23
+ predictions than targets. After the one-to-one matching, the un-matched
24
+ are treated as backgrounds. Thus each query prediction will be assigned
25
+ with `0` or a positive integer indicating the ground truth index:
26
+
27
+ - 0: negative sample, no assigned gt
28
+ - positive integer: positive sample, index (1-based) of assigned gt
29
+
30
+ Args:
31
+ cls_cost (dict, optional): Config of the classification cost.
+ Defaults to dict(type='ClassificationCost', weight=1.).
+ reg_cost (dict, optional): Config of the regression L1 cost.
+ Defaults to dict(type='BBoxL1Cost', weight=1.0).
+ iou_cost (dict, optional): Config of the regression IoU cost.
+ Defaults to dict(type='IoUCost', iou_mode='giou', weight=1.0).
43
+
44
+ def __init__(self,
45
+ cls_cost=dict(type='ClassificationCost', weight=1.),
46
+ reg_cost=dict(type='BBoxL1Cost', weight=1.0),
47
+ iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
48
+ self.cls_cost = build_match_cost(cls_cost)
49
+ self.reg_cost = build_match_cost(reg_cost)
50
+ self.iou_cost = build_match_cost(iou_cost)
51
+
52
+ def assign(self,
53
+ bbox_pred,
54
+ cls_pred,
55
+ gt_bboxes,
56
+ gt_labels,
57
+ img_meta,
58
+ gt_bboxes_ignore=None,
59
+ eps=1e-7):
60
+ """Computes one-to-one matching based on the weighted costs.
61
+
62
+ This method assign each query prediction to a ground truth or
63
+ background. The `assigned_gt_inds` with -1 means don't care,
64
+ 0 means negative sample, and positive number is the index (1-based)
65
+ of assigned gt.
66
+ The assignment is done in the following steps, the order matters.
67
+
68
+ 1. assign every prediction to -1
69
+ 2. compute the weighted costs
70
+ 3. do Hungarian matching on CPU based on the costs
71
+ 4. assign all to 0 (background) first, then for each matched pair
72
+ between predictions and gts, treat this prediction as foreground
73
+ and assign the corresponding gt index (plus 1) to it.
74
+
75
+ Args:
76
+ bbox_pred (Tensor): Predicted boxes with normalized coordinates
77
+ (cx, cy, w, h), which are all in range [0, 1]. Shape
78
+ [num_query, 4].
79
+ cls_pred (Tensor): Predicted classification logits, shape
80
+ [num_query, num_class].
81
+ gt_bboxes (Tensor): Ground truth boxes with unnormalized
82
+ coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
83
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
84
+ img_meta (dict): Meta information for current image.
85
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
86
+ labelled as `ignored`. Default None.
87
+ eps (int | float, optional): A value added to the denominator for
88
+ numerical stability. Default 1e-7.
89
+
90
+ Returns:
91
+ :obj:`AssignResult`: The assigned result.
92
+ """
93
+ assert gt_bboxes_ignore is None, \
94
+ 'Only case when gt_bboxes_ignore is None is supported.'
95
+ num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
96
+
97
+ # 1. assign -1 by default
98
+ assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
99
+ -1,
100
+ dtype=torch.long)
101
+ assigned_labels = bbox_pred.new_full((num_bboxes, ),
102
+ -1,
103
+ dtype=torch.long)
104
+ if num_gts == 0 or num_bboxes == 0:
105
+ # No ground truth or boxes, return empty assignment
106
+ if num_gts == 0:
107
+ # No ground truth, assign all to background
108
+ assigned_gt_inds[:] = 0
109
+ return AssignResult(
110
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
111
+ img_h, img_w, _ = img_meta['img_shape']
112
+ factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
113
+ img_h]).unsqueeze(0)
114
+
115
+ # 2. compute the weighted costs
116
+ # classification and bboxcost.
117
+ cls_cost = self.cls_cost(cls_pred, gt_labels)
118
+ # regression L1 cost
119
+ normalize_gt_bboxes = gt_bboxes / factor
120
+ reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
121
+ # regression iou cost; giou is used by default, as in the official DETR.
122
+ bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
123
+ iou_cost = self.iou_cost(bboxes, gt_bboxes)
124
+ # weighted sum of above three costs
125
+ cost = cls_cost + reg_cost + iou_cost
126
+
127
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
128
+ cost = cost.detach().cpu()
129
+ if linear_sum_assignment is None:
130
+ raise ImportError('Please run "pip install scipy" '
131
+ 'to install scipy first.')
132
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
133
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(
134
+ bbox_pred.device)
135
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(
136
+ bbox_pred.device)
137
+
138
+ # 4. assign backgrounds and foregrounds
139
+ # assign all indices to backgrounds first
140
+ assigned_gt_inds[:] = 0
141
+ # assign foregrounds based on matching results
142
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
143
+ assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
144
+ return AssignResult(
145
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
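A minimal sketch (not part of the commit) of the matching step itself: scipy's linear_sum_assignment applied to a small, made-up cost matrix of queries versus gts:

# Illustrative sketch only, not part of the diff.
import torch
from scipy.optimize import linear_sum_assignment

cost = torch.tensor([[0.9, 0.2],
                     [0.1, 0.8],
                     [0.5, 0.4]])          # 3 queries, 2 gts
rows, cols = linear_sum_assignment(cost.numpy())
print(rows, cols)                          # [0 1] and [1 0]: query 0 -> gt 1, query 1 -> gt 0
# Unmatched queries (here query 2) are treated as background (assigned_gt_inds = 0).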
mmdet/core/bbox/assigners/max_iou_assigner.py ADDED
@@ -0,0 +1,212 @@
1
+ import torch
2
+
3
+ from ..builder import BBOX_ASSIGNERS
4
+ from ..iou_calculators import build_iou_calculator
5
+ from .assign_result import AssignResult
6
+ from .base_assigner import BaseAssigner
7
+
8
+
9
+ @BBOX_ASSIGNERS.register_module()
10
+ class MaxIoUAssigner(BaseAssigner):
11
+ """Assign a corresponding gt bbox or background to each bbox.
12
+
13
+ Each proposal will be assigned `-1`, or a semi-positive integer
14
+ indicating the ground truth index.
15
+
16
+ - -1: negative sample, no assigned gt
17
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
18
+
19
+ Args:
20
+ pos_iou_thr (float): IoU threshold for positive bboxes.
21
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
22
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
23
+ positive bbox. Positive samples can have smaller IoU than
24
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
25
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
26
+ highest overlap with some gt to that gt.
27
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
28
+ `gt_bboxes_ignore` is specified). Negative values mean not
29
+ ignoring any bboxes.
30
+ ignore_wrt_candidates (bool): Whether to compute the iof between
31
+ `bboxes` and `gt_bboxes_ignore`, or the contrary.
32
+ match_low_quality (bool): Whether to allow low quality matches. This is
33
+ usually allowed for RPN and single stage detectors, but not allowed
34
+ in the second stage. Details are demonstrated in Step 4.
35
+ gpu_assign_thr (int): The upper bound of the number of GT for GPU
36
+ assign. When the number of gt is above this threshold, will assign
37
+ on CPU device. Negative values mean not assign on CPU.
38
+ """
39
+
40
+ def __init__(self,
41
+ pos_iou_thr,
42
+ neg_iou_thr,
43
+ min_pos_iou=.0,
44
+ gt_max_assign_all=True,
45
+ ignore_iof_thr=-1,
46
+ ignore_wrt_candidates=True,
47
+ match_low_quality=True,
48
+ gpu_assign_thr=-1,
49
+ iou_calculator=dict(type='BboxOverlaps2D')):
50
+ self.pos_iou_thr = pos_iou_thr
51
+ self.neg_iou_thr = neg_iou_thr
52
+ self.min_pos_iou = min_pos_iou
53
+ self.gt_max_assign_all = gt_max_assign_all
54
+ self.ignore_iof_thr = ignore_iof_thr
55
+ self.ignore_wrt_candidates = ignore_wrt_candidates
56
+ self.gpu_assign_thr = gpu_assign_thr
57
+ self.match_low_quality = match_low_quality
58
+ self.iou_calculator = build_iou_calculator(iou_calculator)
59
+
60
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
61
+ """Assign gt to bboxes.
62
+
63
+ This method assigns a gt bbox to every bbox (proposal/anchor), each bbox
64
+ will be assigned with -1, or a semi-positive number. -1 means negative
65
+ sample, semi-positive number is the index (0-based) of assigned gt.
66
+ The assignment is done in following steps, the order matters.
67
+
68
+ 1. assign every bbox to the background
69
+ 2. assign proposals whose iou with all gts < neg_iou_thr to 0
70
+ 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
71
+ assign it to that gt
72
+ 4. for each gt bbox, assign its nearest proposals (may be more than
73
+ one) to itself
74
+
75
+ Args:
76
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
77
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
78
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
79
+ labelled as `ignored`, e.g., crowd boxes in COCO.
80
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
81
+
82
+ Returns:
83
+ :obj:`AssignResult`: The assign result.
84
+
85
+ Example:
86
+ >>> self = MaxIoUAssigner(0.5, 0.5)
87
+ >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
88
+ >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
89
+ >>> assign_result = self.assign(bboxes, gt_bboxes)
90
+ >>> expected_gt_inds = torch.LongTensor([1, 0])
91
+ >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
92
+ """
93
+ assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
94
+ gt_bboxes.shape[0] > self.gpu_assign_thr) else False
95
+ # compute overlap and assign gt on CPU when number of GT is large
96
+ if assign_on_cpu:
97
+ device = bboxes.device
98
+ bboxes = bboxes.cpu()
99
+ gt_bboxes = gt_bboxes.cpu()
100
+ if gt_bboxes_ignore is not None:
101
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
102
+ if gt_labels is not None:
103
+ gt_labels = gt_labels.cpu()
104
+
105
+ overlaps = self.iou_calculator(gt_bboxes, bboxes)
106
+
107
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
108
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
109
+ if self.ignore_wrt_candidates:
110
+ ignore_overlaps = self.iou_calculator(
111
+ bboxes, gt_bboxes_ignore, mode='iof')
112
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
113
+ else:
114
+ ignore_overlaps = self.iou_calculator(
115
+ gt_bboxes_ignore, bboxes, mode='iof')
116
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
117
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
118
+
119
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
120
+ if assign_on_cpu:
121
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
122
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
123
+ if assign_result.labels is not None:
124
+ assign_result.labels = assign_result.labels.to(device)
125
+ return assign_result
126
+
127
+ def assign_wrt_overlaps(self, overlaps, gt_labels=None):
128
+ """Assign w.r.t. the overlaps of bboxes with gts.
129
+
130
+ Args:
131
+ overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
132
+ shape(k, n).
133
+ gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
134
+
135
+ Returns:
136
+ :obj:`AssignResult`: The assign result.
137
+ """
138
+ num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
139
+
140
+ # 1. assign -1 by default
141
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
142
+ -1,
143
+ dtype=torch.long)
144
+
145
+ if num_gts == 0 or num_bboxes == 0:
146
+ # No ground truth or boxes, return empty assignment
147
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
148
+ if num_gts == 0:
149
+ # No truth, assign everything to background
150
+ assigned_gt_inds[:] = 0
151
+ if gt_labels is None:
152
+ assigned_labels = None
153
+ else:
154
+ assigned_labels = overlaps.new_full((num_bboxes, ),
155
+ -1,
156
+ dtype=torch.long)
157
+ return AssignResult(
158
+ num_gts,
159
+ assigned_gt_inds,
160
+ max_overlaps,
161
+ labels=assigned_labels)
162
+
163
+ # for each anchor, which gt best overlaps with it
164
+ # for each anchor, the max iou of all gts
165
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
166
+ # for each gt, which anchor best overlaps with it
167
+ # for each gt, the max iou of all proposals
168
+ gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
169
+
170
+ # 2. assign negative: below
171
+ # the negative inds are set to be 0
172
+ if isinstance(self.neg_iou_thr, float):
173
+ assigned_gt_inds[(max_overlaps >= 0)
174
+ & (max_overlaps < self.neg_iou_thr)] = 0
175
+ elif isinstance(self.neg_iou_thr, tuple):
176
+ assert len(self.neg_iou_thr) == 2
177
+ assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
178
+ & (max_overlaps < self.neg_iou_thr[1])] = 0
179
+
180
+ # 3. assign positive: above positive IoU threshold
181
+ pos_inds = max_overlaps >= self.pos_iou_thr
182
+ assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
183
+
184
+ if self.match_low_quality:
185
+ # Low-quality matching will overwrite the assigned_gt_inds assigned
186
+ # in Step 3. Thus, the assigned gt might not be the best one for
187
+ # prediction.
188
+ # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
189
+ # gt bbox 1 will be assigned as the best target for bbox A in step 3.
190
+ # However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
191
+ # assigned_gt_inds will be overwritten to be gt bbox 2.
192
+ # This might be the reason that it is not used in ROI Heads.
193
+ for i in range(num_gts):
194
+ if gt_max_overlaps[i] >= self.min_pos_iou:
195
+ if self.gt_max_assign_all:
196
+ max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
197
+ assigned_gt_inds[max_iou_inds] = i + 1
198
+ else:
199
+ assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
200
+
201
+ if gt_labels is not None:
202
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
203
+ pos_inds = torch.nonzero(
204
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
205
+ if pos_inds.numel() > 0:
206
+ assigned_labels[pos_inds] = gt_labels[
207
+ assigned_gt_inds[pos_inds] - 1]
208
+ else:
209
+ assigned_labels = None
210
+
211
+ return AssignResult(
212
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
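A minimal usage sketch (not part of the commit) of how `match_low_quality` interacts with step 4 above; it assumes the package added here is importable as `mmdet` with the module layout shown, and the boxes are made-up values.

import torch
from mmdet.core.bbox.assigners.max_iou_assigner import MaxIoUAssigner

anchor = torch.Tensor([[0., 0., 10., 10.]])   # one anchor
gt = torch.Tensor([[0., 0., 4., 4.]])         # IoU with the anchor is only 0.16

# 0.16 < pos_iou_thr, so step 3 leaves the anchor negative, but step 4
# (low-quality matching) still gives the gt its best-overlapping anchor.
assigner = MaxIoUAssigner(pos_iou_thr=0.5, neg_iou_thr=0.5, match_low_quality=True)
print(assigner.assign(anchor, gt).gt_inds)    # tensor([1])

# With low-quality matching disabled the gt stays unmatched.
assigner = MaxIoUAssigner(pos_iou_thr=0.5, neg_iou_thr=0.5, match_low_quality=False)
print(assigner.assign(anchor, gt).gt_inds)    # tensor([0])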
mmdet/core/bbox/assigners/point_assigner.py ADDED
@@ -0,0 +1,133 @@
1
+ import torch
2
+
3
+ from ..builder import BBOX_ASSIGNERS
4
+ from .assign_result import AssignResult
5
+ from .base_assigner import BaseAssigner
6
+
7
+
8
+ @BBOX_ASSIGNERS.register_module()
9
+ class PointAssigner(BaseAssigner):
10
+ """Assign a corresponding gt bbox or background to each point.
11
+
12
+ Each point will be assigned `0` or a positive integer
13
+ indicating the ground truth index.
14
+
15
+ - 0: negative sample, no assigned gt
16
+ - positive integer: positive sample, index (1-based) of assigned gt
17
+ """
18
+
19
+ def __init__(self, scale=4, pos_num=3):
20
+ self.scale = scale
21
+ self.pos_num = pos_num
22
+
23
+ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
24
+ """Assign gt to points.
25
+
26
+ This method assigns a gt bbox to every point. Each point will be
28
+ assigned 0 (background) or a positive number, which is the 1-based
29
+ index of the assigned gt (consistent with the class docstring
30
+ above).
30
+ The assignment is done in following steps, the order matters.
31
+
32
+ 1. assign every point to the background (0)
33
+ 2. A point is assigned to some gt bbox if
34
+ (i) the point is within the k closest points to the gt bbox
35
+ (ii) the distance between this point and the gt is smaller than
36
+ other gt bboxes
37
+
38
+ Args:
39
+ points (Tensor): points to be assigned, shape(n, 3) while last
40
+ dimension stands for (x, y, stride).
41
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
42
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
43
+ labelled as `ignored`, e.g., crowd boxes in COCO.
44
+ NOTE: currently unused.
45
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
46
+
47
+ Returns:
48
+ :obj:`AssignResult`: The assign result.
49
+ """
50
+ num_points = points.shape[0]
51
+ num_gts = gt_bboxes.shape[0]
52
+
53
+ if num_gts == 0 or num_points == 0:
54
+ # If no truth assign everything to the background
55
+ assigned_gt_inds = points.new_full((num_points, ),
56
+ 0,
57
+ dtype=torch.long)
58
+ if gt_labels is None:
59
+ assigned_labels = None
60
+ else:
61
+ assigned_labels = points.new_full((num_points, ),
62
+ -1,
63
+ dtype=torch.long)
64
+ return AssignResult(
65
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
66
+
67
+ points_xy = points[:, :2]
68
+ points_stride = points[:, 2]
69
+ points_lvl = torch.log2(
70
+ points_stride).int() # [3...,4...,5...,6...,7...]
71
+ lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
72
+
73
+ # assign gt box
74
+ gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
75
+ gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
76
+ scale = self.scale
77
+ gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
78
+ torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
79
+ gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)
80
+
81
+ # stores the assigned gt index of each point
82
+ assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
83
+ # stores the assigned gt dist (to this point) of each point
84
+ assigned_gt_dist = points.new_full((num_points, ), float('inf'))
85
+ points_range = torch.arange(points.shape[0])
86
+
87
+ for idx in range(num_gts):
88
+ gt_lvl = gt_bboxes_lvl[idx]
89
+ # get the index of points in this level
90
+ lvl_idx = gt_lvl == points_lvl
91
+ points_index = points_range[lvl_idx]
92
+ # get the points in this level
93
+ lvl_points = points_xy[lvl_idx, :]
94
+ # get the center point of gt
95
+ gt_point = gt_bboxes_xy[[idx], :]
96
+ # get width and height of gt
97
+ gt_wh = gt_bboxes_wh[[idx], :]
98
+ # compute the distance between gt center and
99
+ # all points in this level
100
+ points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
101
+ # find the nearest k points to gt center in this level
102
+ min_dist, min_dist_index = torch.topk(
103
+ points_gt_dist, self.pos_num, largest=False)
104
+ # the index of nearest k points to gt center in this level
105
+ min_dist_points_index = points_index[min_dist_index]
106
+ # The less_than_recorded_index stores the index
107
+ # of min_dist that is less then the assigned_gt_dist. Where
108
+ # assigned_gt_dist stores the dist from previous assigned gt
109
+ # (if exist) to each point.
110
+ less_than_recorded_index = min_dist < assigned_gt_dist[
111
+ min_dist_points_index]
112
+ # The min_dist_points_index stores the index of points satisfy:
113
+ # (1) it is k nearest to current gt center in this level.
114
+ # (2) it is closer to current gt center than other gt center.
115
+ min_dist_points_index = min_dist_points_index[
116
+ less_than_recorded_index]
117
+ # assign the result
118
+ assigned_gt_inds[min_dist_points_index] = idx + 1
119
+ assigned_gt_dist[min_dist_points_index] = min_dist[
120
+ less_than_recorded_index]
121
+
122
+ if gt_labels is not None:
123
+ assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
124
+ pos_inds = torch.nonzero(
125
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
126
+ if pos_inds.numel() > 0:
127
+ assigned_labels[pos_inds] = gt_labels[
128
+ assigned_gt_inds[pos_inds] - 1]
129
+ else:
130
+ assigned_labels = None
131
+
132
+ return AssignResult(
133
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
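A small sketch of the k-nearest-point rule described in the docstring (not part of the commit; it assumes the module layout shown above, and the values are made up). Points are (x, y, stride) triples; the three points closest to the gt centre on the matching level become positives, the far one stays background.

import torch
from mmdet.core.bbox.assigners.point_assigner import PointAssigner

points = torch.Tensor([[ 4.,  4., 8.],
                       [12.,  4., 8.],
                       [ 4., 12., 8.],
                       [60., 60., 8.]])         # all on the stride-8 level
gt_bboxes = torch.Tensor([[0., 0., 16., 16.]])  # centre at (8, 8)

assigner = PointAssigner(scale=4, pos_num=3)
print(assigner.assign(points, gt_bboxes).gt_inds)  # tensor([1, 1, 1, 0])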
mmdet/core/bbox/assigners/region_assigner.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+
3
+ from mmdet.core import anchor_inside_flags
4
+ from ..builder import BBOX_ASSIGNERS
5
+ from .assign_result import AssignResult
6
+ from .base_assigner import BaseAssigner
7
+
8
+
9
+ def calc_region(bbox, ratio, stride, featmap_size=None):
10
+ """Calculate region of the box defined by the ratio, the ratio is from the
11
+ center of the box to every edge."""
12
+ # project bbox on the feature
13
+ f_bbox = bbox / stride
14
+ x1 = torch.round((1 - ratio) * f_bbox[0] + ratio * f_bbox[2])
15
+ y1 = torch.round((1 - ratio) * f_bbox[1] + ratio * f_bbox[3])
16
+ x2 = torch.round(ratio * f_bbox[0] + (1 - ratio) * f_bbox[2])
17
+ y2 = torch.round(ratio * f_bbox[1] + (1 - ratio) * f_bbox[3])
18
+ if featmap_size is not None:
19
+ x1 = x1.clamp(min=0, max=featmap_size[1])
20
+ y1 = y1.clamp(min=0, max=featmap_size[0])
21
+ x2 = x2.clamp(min=0, max=featmap_size[1])
22
+ y2 = y2.clamp(min=0, max=featmap_size[0])
23
+ return (x1, y1, x2, y2)
24
+
25
+
26
+ def anchor_ctr_inside_region_flags(anchors, stride, region):
27
+ """Get the flag indicate whether anchor centers are inside regions."""
28
+ x1, y1, x2, y2 = region
29
+ f_anchors = anchors / stride
30
+ x = (f_anchors[:, 0] + f_anchors[:, 2]) * 0.5
31
+ y = (f_anchors[:, 1] + f_anchors[:, 3]) * 0.5
32
+ flags = (x >= x1) & (x <= x2) & (y >= y1) & (y <= y2)
33
+ return flags
34
+
35
+
36
+ @BBOX_ASSIGNERS.register_module()
37
+ class RegionAssigner(BaseAssigner):
38
+ """Assign a corresponding gt bbox or background to each bbox.
39
+
40
+ Each proposal will be assigned `-1`, `0`, or a positive integer
41
+ indicating the ground truth index.
42
+
43
+ - -1: don't care
44
+ - 0: negative sample, no assigned gt
45
+ - positive integer: positive sample, index (1-based) of assigned gt
46
+
47
+ Args:
48
+ center_ratio: ratio of the region in the center of the bbox to
49
+ define positive sample.
50
+ ignore_ratio: ratio of the region to define ignore samples.
51
+ """
52
+
53
+ def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
54
+ self.center_ratio = center_ratio
55
+ self.ignore_ratio = ignore_ratio
56
+
57
+ def assign(self,
58
+ mlvl_anchors,
59
+ mlvl_valid_flags,
60
+ gt_bboxes,
61
+ img_meta,
62
+ featmap_sizes,
63
+ anchor_scale,
64
+ anchor_strides,
65
+ gt_bboxes_ignore=None,
66
+ gt_labels=None,
67
+ allowed_border=0):
68
+ """Assign gt to anchors.
69
+
70
+ This method assigns a gt bbox to every bbox (proposal/anchor); each bbox
71
+ will be assigned with -1, 0, or a positive number. -1 means don't care,
72
+ 0 means negative sample, positive number is the index (1-based) of
73
+ assigned gt.
74
+ The assignment is done in following steps, the order matters.
75
+
76
+ 1. Assign every anchor to 0 (negative)
77
+ For each gt_bboxes:
78
+ 2. Compute ignore flags based on ignore_region then
79
+ assign -1 to anchors w.r.t. ignore flags
80
+ 3. Compute pos flags based on center_region then
81
+ assign gt_bboxes to anchors w.r.t. pos flags
82
+ 4. Compute ignore flags based on adjacent anchor lvl then
83
+ assign -1 to anchors w.r.t. ignore flags
84
+ 5. Assign anchor outside of image to -1
85
+
86
+ Args:
87
+ mlvl_anchors (list[Tensor]): Multi level anchors.
88
+ mlvl_valid_flags (list[Tensor]): Multi level valid flags.
89
+ gt_bboxes (Tensor): Ground truth bboxes of image
90
+ img_meta (dict): Meta info of image.
91
+ featmap_sizes (list[Tensor]): Feature map size of each level.
92
+ anchor_scale (int): Scale of the anchor.
93
+ anchor_strides (list[int]): Stride of the anchor.
94
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
95
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
96
+ labelled as `ignored`, e.g., crowd boxes in COCO.
97
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
98
+ allowed_border (int, optional): The border to allow the valid
99
+ anchor. Defaults to 0.
100
+
101
+ Returns:
102
+ :obj:`AssignResult`: The assign result.
103
+ """
104
+ if gt_bboxes_ignore is not None:
105
+ raise NotImplementedError
106
+
107
+ num_gts = gt_bboxes.shape[0]
108
+ num_bboxes = sum(x.shape[0] for x in mlvl_anchors)
109
+
110
+ if num_gts == 0 or num_bboxes == 0:
111
+ # No ground truth or boxes, return empty assignment
112
+ max_overlaps = gt_bboxes.new_zeros((num_bboxes, ))
113
+ assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ),
114
+ dtype=torch.long)
115
+ if gt_labels is None:
116
+ assigned_labels = None
117
+ else:
118
+ assigned_labels = gt_bboxes.new_full((num_bboxes, ),
119
+ -1,
120
+ dtype=torch.long)
121
+ return AssignResult(
122
+ num_gts,
123
+ assigned_gt_inds,
124
+ max_overlaps,
125
+ labels=assigned_labels)
126
+
127
+ num_lvls = len(mlvl_anchors)
128
+ r1 = (1 - self.center_ratio) / 2
129
+ r2 = (1 - self.ignore_ratio) / 2
130
+
131
+ scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
132
+ (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
133
+ min_anchor_size = scale.new_full(
134
+ (1, ), float(anchor_scale * anchor_strides[0]))
135
+ target_lvls = torch.floor(
136
+ torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
137
+ target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
138
+
139
+ # 1. assign 0 (negative) by default
140
+ mlvl_assigned_gt_inds = []
141
+ mlvl_ignore_flags = []
142
+ for lvl in range(num_lvls):
143
+ h, w = featmap_sizes[lvl]
144
+ assert h * w == mlvl_anchors[lvl].shape[0]
145
+ assigned_gt_inds = gt_bboxes.new_full((h * w, ),
146
+ 0,
147
+ dtype=torch.long)
148
+ ignore_flags = torch.zeros_like(assigned_gt_inds)
149
+ mlvl_assigned_gt_inds.append(assigned_gt_inds)
150
+ mlvl_ignore_flags.append(ignore_flags)
151
+
152
+ for gt_id in range(num_gts):
153
+ lvl = target_lvls[gt_id].item()
154
+ featmap_size = featmap_sizes[lvl]
155
+ stride = anchor_strides[lvl]
156
+ anchors = mlvl_anchors[lvl]
157
+ gt_bbox = gt_bboxes[gt_id, :4]
158
+
159
+ # Compute regions
160
+ ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
161
+ ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)
162
+
163
+ # 2. Assign -1 to ignore flags
164
+ ignore_flags = anchor_ctr_inside_region_flags(
165
+ anchors, stride, ignore_region)
166
+ mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
167
+
168
+ # 3. Assign gt_bboxes to pos flags
169
+ pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
170
+ ctr_region)
171
+ mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1
172
+
173
+ # 4. Assign -1 to ignore adjacent lvl
174
+ if lvl > 0:
175
+ d_lvl = lvl - 1
176
+ d_anchors = mlvl_anchors[d_lvl]
177
+ d_featmap_size = featmap_sizes[d_lvl]
178
+ d_stride = anchor_strides[d_lvl]
179
+ d_ignore_region = calc_region(gt_bbox, r2, d_stride,
180
+ d_featmap_size)
181
+ ignore_flags = anchor_ctr_inside_region_flags(
182
+ d_anchors, d_stride, d_ignore_region)
183
+ mlvl_ignore_flags[d_lvl][ignore_flags] = 1
184
+ if lvl < num_lvls - 1:
185
+ u_lvl = lvl + 1
186
+ u_anchors = mlvl_anchors[u_lvl]
187
+ u_featmap_size = featmap_sizes[u_lvl]
188
+ u_stride = anchor_strides[u_lvl]
189
+ u_ignore_region = calc_region(gt_bbox, r2, u_stride,
190
+ u_featmap_size)
191
+ ignore_flags = anchor_ctr_inside_region_flags(
192
+ u_anchors, u_stride, u_ignore_region)
193
+ mlvl_ignore_flags[u_lvl][ignore_flags] = 1
194
+
195
+ # 4. (cont.) Assign -1 to ignore adjacent lvl
196
+ for lvl in range(num_lvls):
197
+ ignore_flags = mlvl_ignore_flags[lvl]
198
+ mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
199
+
200
+ # 5. Assign -1 to anchor outside of image
201
+ flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
202
+ flat_anchors = torch.cat(mlvl_anchors)
203
+ flat_valid_flags = torch.cat(mlvl_valid_flags)
204
+ assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
205
+ flat_valid_flags.shape[0])
206
+ inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
207
+ img_meta['img_shape'],
208
+ allowed_border)
209
+ outside_flags = ~inside_flags
210
+ flat_assigned_gt_inds[outside_flags] = -1
211
+
212
+ if gt_labels is not None:
213
+ assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
214
+ pos_flags = flat_assigned_gt_inds > 0
215
+ assigned_labels[pos_flags] = gt_labels[
216
+ flat_assigned_gt_inds[pos_flags] - 1]
217
+ else:
218
+ assigned_labels = None
219
+
220
+ return AssignResult(
221
+ num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)
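To make the region logic above concrete, here is a sketch (not part of the commit) of the two helper functions on a single gt box, assuming the module layout shown; with center_ratio=0.2 the positive region covers the central 20% of the box, projected onto the stride-8 feature map.

import torch
from mmdet.core.bbox.assigners.region_assigner import (
    anchor_ctr_inside_region_flags, calc_region)

gt_bbox = torch.Tensor([0., 0., 64., 64.])
r1 = (1 - 0.2) / 2                      # center_ratio=0.2 -> r1 = 0.4
ctr_region = calc_region(gt_bbox, r1, stride=8)
print(ctr_region)                       # approx (3., 3., 5., 5.) in feature cells

anchors = torch.Tensor([[28., 28., 36., 36.],   # centre at feature cell (4, 4)
                        [ 0.,  0.,  8.,  8.]])  # centre at (0.5, 0.5)
print(anchor_ctr_inside_region_flags(anchors, 8, ctr_region))  # tensor([ True, False])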
mmdet/core/bbox/builder.py ADDED
@@ -0,0 +1,20 @@
1
+ from mmcv.utils import Registry, build_from_cfg
2
+
3
+ BBOX_ASSIGNERS = Registry('bbox_assigner')
4
+ BBOX_SAMPLERS = Registry('bbox_sampler')
5
+ BBOX_CODERS = Registry('bbox_coder')
6
+
7
+
8
+ def build_assigner(cfg, **default_args):
9
+ """Builder of box assigner."""
10
+ return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)
11
+
12
+
13
+ def build_sampler(cfg, **default_args):
14
+ """Builder of box sampler."""
15
+ return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
16
+
17
+
18
+ def build_bbox_coder(cfg, **default_args):
19
+ """Builder of box coder."""
20
+ return build_from_cfg(cfg, BBOX_CODERS, default_args)
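A short sketch of config-driven construction via these registries (not part of the commit). The string in `type` is resolved to a class registered with the matching decorator and the remaining keys become constructor arguments; the assigner module only needs to be imported once so that its registration decorator runs.

from mmdet.core.bbox.builder import build_assigner
from mmdet.core.bbox.assigners import max_iou_assigner  # noqa: F401  (registers MaxIoUAssigner)

assigner_cfg = dict(type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3)
assigner = build_assigner(assigner_cfg)  # -> MaxIoUAssigner(pos_iou_thr=0.7, neg_iou_thr=0.3)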
mmdet/core/bbox/coder/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .base_bbox_coder import BaseBBoxCoder
2
+ from .bucketing_bbox_coder import BucketingBBoxCoder
3
+ from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
4
+ from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
5
+ from .pseudo_bbox_coder import PseudoBBoxCoder
6
+ from .tblr_bbox_coder import TBLRBBoxCoder
7
+ from .yolo_bbox_coder import YOLOBBoxCoder
8
+
9
+ __all__ = [
10
+ 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
11
+ 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
12
+ 'BucketingBBoxCoder'
13
+ ]
mmdet/core/bbox/coder/base_bbox_coder.py ADDED
@@ -0,0 +1,17 @@
1
+ from abc import ABCMeta, abstractmethod
2
+
3
+
4
+ class BaseBBoxCoder(metaclass=ABCMeta):
5
+ """Base bounding box coder."""
6
+
7
+ def __init__(self, **kwargs):
8
+ pass
9
+
10
+ @abstractmethod
11
+ def encode(self, bboxes, gt_bboxes):
12
+ """Encode deltas between bboxes and ground truth boxes."""
13
+
14
+ @abstractmethod
15
+ def decode(self, bboxes, bboxes_pred):
16
+ """Decode the predicted bboxes according to prediction and base
17
+ boxes."""
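A hypothetical example (not in this commit) of how a new coder plugs into this base class and the BBOX_CODERS registry: implement `encode`/`decode` and register the class, after which it can be built from a config dict like the coders below.

from mmdet.core.bbox.builder import BBOX_CODERS, build_bbox_coder
from mmdet.core.bbox.coder.base_bbox_coder import BaseBBoxCoder

@BBOX_CODERS.register_module()
class IdentityBBoxCoder(BaseBBoxCoder):
    """Trivial illustrative coder: the 'encoding' is the gt box itself."""

    def encode(self, bboxes, gt_bboxes):
        return gt_bboxes

    def decode(self, bboxes, bboxes_pred):
        return bboxes_pred

coder = build_bbox_coder(dict(type='IdentityBBoxCoder'))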
mmdet/core/bbox/coder/bucketing_bbox_coder.py ADDED
@@ -0,0 +1,350 @@
1
+ import mmcv
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ..builder import BBOX_CODERS
7
+ from ..transforms import bbox_rescale
8
+ from .base_bbox_coder import BaseBBoxCoder
9
+
10
+
11
+ @BBOX_CODERS.register_module()
12
+ class BucketingBBoxCoder(BaseBBoxCoder):
13
+ """Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).
14
+
15
+ Boundary Localization with Bucketing and Bucketing Guided Rescoring
16
+ are implemented here.
17
+
18
+ Please refer to https://arxiv.org/abs/1912.04260 for more details.
19
+
20
+ Args:
21
+ num_buckets (int): Number of buckets.
22
+ scale_factor (int): Scale factor of proposals to generate buckets.
23
+ offset_topk (int): Topk buckets are used to generate
24
+ bucket fine regression targets. Defaults to 2.
25
+ offset_upperbound (float): Offset upperbound to generate
26
+ bucket fine regression targets.
27
+ To avoid too large offset displacements. Defaults to 1.0.
28
+ cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
29
+ Defaults to True.
30
+ clip_border (bool, optional): Whether clip the objects outside the
31
+ border of the image. Defaults to True.
32
+ """
33
+
34
+ def __init__(self,
35
+ num_buckets,
36
+ scale_factor,
37
+ offset_topk=2,
38
+ offset_upperbound=1.0,
39
+ cls_ignore_neighbor=True,
40
+ clip_border=True):
41
+ super(BucketingBBoxCoder, self).__init__()
42
+ self.num_buckets = num_buckets
43
+ self.scale_factor = scale_factor
44
+ self.offset_topk = offset_topk
45
+ self.offset_upperbound = offset_upperbound
46
+ self.cls_ignore_neighbor = cls_ignore_neighbor
47
+ self.clip_border = clip_border
48
+
49
+ def encode(self, bboxes, gt_bboxes):
50
+ """Get bucketing estimation and fine regression targets during
51
+ training.
52
+
53
+ Args:
54
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
55
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
56
+ ground truth boxes.
57
+
58
+ Returns:
59
+ encoded_bboxes(tuple[Tensor]): bucketing estimation
60
+ and fine regression targets and weights
61
+ """
62
+
63
+ assert bboxes.size(0) == gt_bboxes.size(0)
64
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
65
+ encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
66
+ self.scale_factor, self.offset_topk,
67
+ self.offset_upperbound,
68
+ self.cls_ignore_neighbor)
69
+ return encoded_bboxes
70
+
71
+ def decode(self, bboxes, pred_bboxes, max_shape=None):
72
+ """Apply transformation `pred_bboxes` to `boxes`.
73
+ Args:
74
+ boxes (torch.Tensor): Basic boxes.
75
+ pred_bboxes (torch.Tensor): Predictions for bucketing estimation
76
+ and fine regression
77
+ max_shape (tuple[int], optional): Maximum shape of boxes.
78
+ Defaults to None.
79
+
80
+ Returns:
81
+ torch.Tensor: Decoded boxes.
82
+ """
83
+ assert len(pred_bboxes) == 2
84
+ cls_preds, offset_preds = pred_bboxes
85
+ assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
86
+ 0) == bboxes.size(0)
87
+ decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
88
+ self.num_buckets, self.scale_factor,
89
+ max_shape, self.clip_border)
90
+
91
+ return decoded_bboxes
92
+
93
+
94
+ @mmcv.jit(coderize=True)
95
+ def generat_buckets(proposals, num_buckets, scale_factor=1.0):
96
+ """Generate buckets w.r.t bucket number and scale factor of proposals.
97
+
98
+ Args:
99
+ proposals (Tensor): Shape (n, 4)
100
+ num_buckets (int): Number of buckets.
101
+ scale_factor (float): Scale factor to rescale proposals.
102
+
103
+ Returns:
104
+ tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
105
+ t_buckets, d_buckets)
106
+
107
+ - bucket_w: Width of buckets on x-axis. Shape (n, ).
108
+ - bucket_h: Height of buckets on y-axis. Shape (n, ).
109
+ - l_buckets: Left buckets. Shape (n, ceil(side_num/2)).
110
+ - r_buckets: Right buckets. Shape (n, ceil(side_num/2)).
111
+ - t_buckets: Top buckets. Shape (n, ceil(side_num/2)).
112
+ - d_buckets: Down buckets. Shape (n, ceil(side_num/2)).
113
+ """
114
+ proposals = bbox_rescale(proposals, scale_factor)
115
+
116
+ # number of buckets in each side
117
+ side_num = int(np.ceil(num_buckets / 2.0))
118
+ pw = proposals[..., 2] - proposals[..., 0]
119
+ ph = proposals[..., 3] - proposals[..., 1]
120
+ px1 = proposals[..., 0]
121
+ py1 = proposals[..., 1]
122
+ px2 = proposals[..., 2]
123
+ py2 = proposals[..., 3]
124
+
125
+ bucket_w = pw / num_buckets
126
+ bucket_h = ph / num_buckets
127
+
128
+ # left buckets
129
+ l_buckets = px1[:, None] + (0.5 + torch.arange(
130
+ 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
131
+ # right buckets
132
+ r_buckets = px2[:, None] - (0.5 + torch.arange(
133
+ 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
134
+ # top buckets
135
+ t_buckets = py1[:, None] + (0.5 + torch.arange(
136
+ 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
137
+ # down buckets
138
+ d_buckets = py2[:, None] - (0.5 + torch.arange(
139
+ 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
140
+ return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets
141
+
142
+
143
+ @mmcv.jit(coderize=True)
144
+ def bbox2bucket(proposals,
145
+ gt,
146
+ num_buckets,
147
+ scale_factor,
148
+ offset_topk=2,
149
+ offset_upperbound=1.0,
150
+ cls_ignore_neighbor=True):
151
+ """Generate buckets estimation and fine regression targets.
152
+
153
+ Args:
154
+ proposals (Tensor): Shape (n, 4)
155
+ gt (Tensor): Shape (n, 4)
156
+ num_buckets (int): Number of buckets.
157
+ scale_factor (float): Scale factor to rescale proposals.
158
+ offset_topk (int): Topk buckets are used to generate
159
+ bucket fine regression targets. Defaults to 2.
160
+ offset_upperbound (float): Offset allowance to generate
161
+ bucket fine regression targets.
162
+ To avoid too large offset displacements. Defaults to 1.0.
163
+ cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
164
+ Defaults to True.
165
+
166
+ Returns:
167
+ tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).
168
+
169
+ - offsets: Fine regression targets. \
170
+ Shape (n, num_buckets*2).
171
+ - offsets_weights: Fine regression weights. \
172
+ Shape (n, num_buckets*2).
173
+ - bucket_labels: Bucketing estimation labels. \
174
+ Shape (n, num_buckets*2).
175
+ - cls_weights: Bucketing estimation weights. \
176
+ Shape (n, num_buckets*2).
177
+ """
178
+ assert proposals.size() == gt.size()
179
+
180
+ # generate buckets
181
+ proposals = proposals.float()
182
+ gt = gt.float()
183
+ (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
184
+ d_buckets) = generat_buckets(proposals, num_buckets, scale_factor)
185
+
186
+ gx1 = gt[..., 0]
187
+ gy1 = gt[..., 1]
188
+ gx2 = gt[..., 2]
189
+ gy2 = gt[..., 3]
190
+
191
+ # generate offset targets and weights
192
+ # offsets from buckets to gts
193
+ l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
194
+ r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
195
+ t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
196
+ d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]
197
+
198
+ # select top-k nearset buckets
199
+ l_topk, l_label = l_offsets.abs().topk(
200
+ offset_topk, dim=1, largest=False, sorted=True)
201
+ r_topk, r_label = r_offsets.abs().topk(
202
+ offset_topk, dim=1, largest=False, sorted=True)
203
+ t_topk, t_label = t_offsets.abs().topk(
204
+ offset_topk, dim=1, largest=False, sorted=True)
205
+ d_topk, d_label = d_offsets.abs().topk(
206
+ offset_topk, dim=1, largest=False, sorted=True)
207
+
208
+ offset_l_weights = l_offsets.new_zeros(l_offsets.size())
209
+ offset_r_weights = r_offsets.new_zeros(r_offsets.size())
210
+ offset_t_weights = t_offsets.new_zeros(t_offsets.size())
211
+ offset_d_weights = d_offsets.new_zeros(d_offsets.size())
212
+ inds = torch.arange(0, proposals.size(0)).to(proposals).long()
213
+
214
+ # generate offset weights of top-k nearset buckets
215
+ for k in range(offset_topk):
216
+ if k >= 1:
217
+ offset_l_weights[inds, l_label[:,
218
+ k]] = (l_topk[:, k] <
219
+ offset_upperbound).float()
220
+ offset_r_weights[inds, r_label[:,
221
+ k]] = (r_topk[:, k] <
222
+ offset_upperbound).float()
223
+ offset_t_weights[inds, t_label[:,
224
+ k]] = (t_topk[:, k] <
225
+ offset_upperbound).float()
226
+ offset_d_weights[inds, d_label[:,
227
+ k]] = (d_topk[:, k] <
228
+ offset_upperbound).float()
229
+ else:
230
+ offset_l_weights[inds, l_label[:, k]] = 1.0
231
+ offset_r_weights[inds, r_label[:, k]] = 1.0
232
+ offset_t_weights[inds, t_label[:, k]] = 1.0
233
+ offset_d_weights[inds, d_label[:, k]] = 1.0
234
+
235
+ offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
236
+ offsets_weights = torch.cat([
237
+ offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights
238
+ ],
239
+ dim=-1)
240
+
241
+ # generate bucket labels and weight
242
+ side_num = int(np.ceil(num_buckets / 2.0))
243
+ labels = torch.stack(
244
+ [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)
245
+
246
+ batch_size = labels.size(0)
247
+ bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size,
248
+ -1).float()
249
+ bucket_cls_l_weights = (l_offsets.abs() < 1).float()
250
+ bucket_cls_r_weights = (r_offsets.abs() < 1).float()
251
+ bucket_cls_t_weights = (t_offsets.abs() < 1).float()
252
+ bucket_cls_d_weights = (d_offsets.abs() < 1).float()
253
+ bucket_cls_weights = torch.cat([
254
+ bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
255
+ bucket_cls_d_weights
256
+ ],
257
+ dim=-1)
258
+ # ignore second nearest buckets for cls if necessary
259
+ if cls_ignore_neighbor:
260
+ bucket_cls_weights = (~((bucket_cls_weights == 1) &
261
+ (bucket_labels == 0))).float()
262
+ else:
263
+ bucket_cls_weights[:] = 1.0
264
+ return offsets, offsets_weights, bucket_labels, bucket_cls_weights
265
+
266
+
267
+ @mmcv.jit(coderize=True)
268
+ def bucket2bbox(proposals,
269
+ cls_preds,
270
+ offset_preds,
271
+ num_buckets,
272
+ scale_factor=1.0,
273
+ max_shape=None,
274
+ clip_border=True):
275
+ """Apply bucketing estimation (cls preds) and fine regression (offset
276
+ preds) to generate det bboxes.
277
+
278
+ Args:
279
+ proposals (Tensor): Boxes to be transformed. Shape (n, 4)
280
+ cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2).
281
+ offset_preds (Tensor): fine regression. Shape (n, num_buckets*2).
282
+ num_buckets (int): Number of buckets.
283
+ scale_factor (float): Scale factor to rescale proposals.
284
+ max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
285
+ clip_border (bool, optional): Whether clip the objects outside the
286
+ border of the image. Defaults to True.
287
+
288
+ Returns:
289
+ tuple[Tensor]: (bboxes, loc_confidence).
290
+
291
+ - bboxes: predicted bboxes. Shape (n, 4)
292
+ - loc_confidence: localization confidence of predicted bboxes.
293
+ Shape (n,).
294
+ """
295
+
296
+ side_num = int(np.ceil(num_buckets / 2.0))
297
+ cls_preds = cls_preds.view(-1, side_num)
298
+ offset_preds = offset_preds.view(-1, side_num)
299
+
300
+ scores = F.softmax(cls_preds, dim=1)
301
+ score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)
302
+
303
+ rescaled_proposals = bbox_rescale(proposals, scale_factor)
304
+
305
+ pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
306
+ ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
307
+ px1 = rescaled_proposals[..., 0]
308
+ py1 = rescaled_proposals[..., 1]
309
+ px2 = rescaled_proposals[..., 2]
310
+ py2 = rescaled_proposals[..., 3]
311
+
312
+ bucket_w = pw / num_buckets
313
+ bucket_h = ph / num_buckets
314
+
315
+ score_inds_l = score_label[0::4, 0]
316
+ score_inds_r = score_label[1::4, 0]
317
+ score_inds_t = score_label[2::4, 0]
318
+ score_inds_d = score_label[3::4, 0]
319
+ l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
320
+ r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
321
+ t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
322
+ d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h
323
+
324
+ offsets = offset_preds.view(-1, 4, side_num)
325
+ inds = torch.arange(proposals.size(0)).to(proposals).long()
326
+ l_offsets = offsets[:, 0, :][inds, score_inds_l]
327
+ r_offsets = offsets[:, 1, :][inds, score_inds_r]
328
+ t_offsets = offsets[:, 2, :][inds, score_inds_t]
329
+ d_offsets = offsets[:, 3, :][inds, score_inds_d]
330
+
331
+ x1 = l_buckets - l_offsets * bucket_w
332
+ x2 = r_buckets - r_offsets * bucket_w
333
+ y1 = t_buckets - t_offsets * bucket_h
334
+ y2 = d_buckets - d_offsets * bucket_h
335
+
336
+ if clip_border and max_shape is not None:
337
+ x1 = x1.clamp(min=0, max=max_shape[1] - 1)
338
+ y1 = y1.clamp(min=0, max=max_shape[0] - 1)
339
+ x2 = x2.clamp(min=0, max=max_shape[1] - 1)
340
+ y2 = y2.clamp(min=0, max=max_shape[0] - 1)
341
+ bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
342
+ dim=-1)
343
+
344
+ # bucketing guided rescoring
345
+ loc_confidence = score_topk[:, 0]
346
+ top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
347
+ loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
348
+ loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)
349
+
350
+ return bboxes, loc_confidence
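To make the bucket geometry concrete, a small sketch (not part of the commit, with a made-up proposal) using the generat_buckets helper defined above: a 64-pixel-wide proposal with num_buckets=8 gives bucket_w = 8, and each side gets ceil(8/2) = 4 bucket centres stepping inwards from its edge.

import torch
from mmdet.core.bbox.coder.bucketing_bbox_coder import generat_buckets

proposal = torch.Tensor([[0., 0., 64., 64.]])
bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets = \
    generat_buckets(proposal, num_buckets=8, scale_factor=1.0)
print(bucket_w)    # tensor([8.])
print(l_buckets)   # tensor([[ 4., 12., 20., 28.]])  centres walking in from x1
print(r_buckets)   # tensor([[60., 52., 44., 36.]])  centres walking in from x2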
mmdet/core/bbox/coder/delta_xywh_bbox_coder.py ADDED
@@ -0,0 +1,237 @@
1
+ import mmcv
2
+ import numpy as np
3
+ import torch
4
+
5
+ from ..builder import BBOX_CODERS
6
+ from .base_bbox_coder import BaseBBoxCoder
7
+
8
+
9
+ @BBOX_CODERS.register_module()
10
+ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
11
+ """Delta XYWH BBox coder.
12
+
13
+ Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
14
+ this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
15
+ decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).
16
+
17
+ Args:
18
+ target_means (Sequence[float]): Denormalizing means of target for
19
+ delta coordinates
20
+ target_stds (Sequence[float]): Denormalizing standard deviation of
21
+ target for delta coordinates
22
+ clip_border (bool, optional): Whether clip the objects outside the
23
+ border of the image. Defaults to True.
24
+ """
25
+
26
+ def __init__(self,
27
+ target_means=(0., 0., 0., 0.),
28
+ target_stds=(1., 1., 1., 1.),
29
+ clip_border=True):
30
+ super(BaseBBoxCoder, self).__init__()
31
+ self.means = target_means
32
+ self.stds = target_stds
33
+ self.clip_border = clip_border
34
+
35
+ def encode(self, bboxes, gt_bboxes):
36
+ """Get box regression transformation deltas that can be used to
37
+ transform the ``bboxes`` into the ``gt_bboxes``.
38
+
39
+ Args:
40
+ bboxes (torch.Tensor): Source boxes, e.g., object proposals.
41
+ gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
42
+ ground-truth boxes.
43
+
44
+ Returns:
45
+ torch.Tensor: Box transformation deltas
46
+ """
47
+
48
+ assert bboxes.size(0) == gt_bboxes.size(0)
49
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
50
+ encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
51
+ return encoded_bboxes
52
+
53
+ def decode(self,
54
+ bboxes,
55
+ pred_bboxes,
56
+ max_shape=None,
57
+ wh_ratio_clip=16 / 1000):
58
+ """Apply transformation `pred_bboxes` to `boxes`.
59
+
60
+ Args:
61
+ bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
62
+ pred_bboxes (Tensor): Encoded offsets with respect to each roi.
63
+ Has shape (B, N, num_classes * 4) or (B, N, 4) or
64
+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
65
+ when rois is a grid of anchors.Offset encoding follows [1]_.
66
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
67
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
68
+ (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
69
+ the max_shape should be a Sequence[Sequence[int]]
70
+ and the length of max_shape should also be B.
71
+ wh_ratio_clip (float, optional): The allowed ratio between
72
+ width and height.
73
+
74
+ Returns:
75
+ torch.Tensor: Decoded boxes.
76
+ """
77
+
78
+ assert pred_bboxes.size(0) == bboxes.size(0)
79
+ if pred_bboxes.ndim == 3:
80
+ assert pred_bboxes.size(1) == bboxes.size(1)
81
+ decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
82
+ max_shape, wh_ratio_clip, self.clip_border)
83
+
84
+ return decoded_bboxes
85
+
86
+
87
+ @mmcv.jit(coderize=True)
88
+ def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
89
+ """Compute deltas of proposals w.r.t. gt.
90
+
91
+ We usually compute the deltas of x, y, w, h of proposals w.r.t ground
92
+ truth bboxes to get regression target.
93
+ This is the inverse function of :func:`delta2bbox`.
94
+
95
+ Args:
96
+ proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
97
+ gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
98
+ means (Sequence[float]): Denormalizing means for delta coordinates
99
+ stds (Sequence[float]): Denormalizing standard deviation for delta
100
+ coordinates
101
+
102
+ Returns:
103
+ Tensor: deltas with shape (N, 4), where columns represent dx, dy,
104
+ dw, dh.
105
+ """
106
+ assert proposals.size() == gt.size()
107
+
108
+ proposals = proposals.float()
109
+ gt = gt.float()
110
+ px = (proposals[..., 0] + proposals[..., 2]) * 0.5
111
+ py = (proposals[..., 1] + proposals[..., 3]) * 0.5
112
+ pw = proposals[..., 2] - proposals[..., 0]
113
+ ph = proposals[..., 3] - proposals[..., 1]
114
+
115
+ gx = (gt[..., 0] + gt[..., 2]) * 0.5
116
+ gy = (gt[..., 1] + gt[..., 3]) * 0.5
117
+ gw = gt[..., 2] - gt[..., 0]
118
+ gh = gt[..., 3] - gt[..., 1]
119
+
120
+ dx = (gx - px) / pw
121
+ dy = (gy - py) / ph
122
+ dw = torch.log(gw / pw)
123
+ dh = torch.log(gh / ph)
124
+ deltas = torch.stack([dx, dy, dw, dh], dim=-1)
125
+
126
+ means = deltas.new_tensor(means).unsqueeze(0)
127
+ stds = deltas.new_tensor(stds).unsqueeze(0)
128
+ deltas = deltas.sub_(means).div_(stds)
129
+
130
+ return deltas
131
+
132
+
133
+ @mmcv.jit(coderize=True)
134
+ def delta2bbox(rois,
135
+ deltas,
136
+ means=(0., 0., 0., 0.),
137
+ stds=(1., 1., 1., 1.),
138
+ max_shape=None,
139
+ wh_ratio_clip=16 / 1000,
140
+ clip_border=True):
141
+ """Apply deltas to shift/scale base boxes.
142
+
143
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
144
+ network outputs used to shift/scale those boxes.
145
+ This is the inverse function of :func:`bbox2delta`.
146
+
147
+ Args:
148
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
149
+ deltas (Tensor): Encoded offsets with respect to each roi.
150
+ Has shape (B, N, num_classes * 4) or (B, N, 4) or
151
+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
152
+ when rois is a grid of anchors. Offset encoding follows [1]_.
153
+ means (Sequence[float]): Denormalizing means for delta coordinates
154
+ stds (Sequence[float]): Denormalizing standard deviation for delta
155
+ coordinates
156
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
157
+ Sequence[int]], optional): Maximum bounds for boxes, specifies
158
+ (H, W, C) or (H, W). If rois shape is (B, N, 4), then
159
+ the max_shape should be a Sequence[Sequence[int]]
160
+ and the length of max_shape should also be B.
161
+ wh_ratio_clip (float): Maximum aspect ratio for boxes.
162
+ clip_border (bool, optional): Whether clip the objects outside the
163
+ border of the image. Defaults to True.
164
+
165
+ Returns:
166
+ Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
167
+ (N, num_classes * 4) or (N, 4), where 4 represent
168
+ tl_x, tl_y, br_x, br_y.
169
+
170
+ References:
171
+ .. [1] https://arxiv.org/abs/1311.2524
172
+
173
+ Example:
174
+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
175
+ >>> [ 0., 0., 1., 1.],
176
+ >>> [ 0., 0., 1., 1.],
177
+ >>> [ 5., 5., 5., 5.]])
178
+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
179
+ >>> [ 1., 1., 1., 1.],
180
+ >>> [ 0., 0., 2., -1.],
181
+ >>> [ 0.7, -1.9, -0.5, 0.3]])
182
+ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
183
+ tensor([[0.0000, 0.0000, 1.0000, 1.0000],
184
+ [0.1409, 0.1409, 2.8591, 2.8591],
185
+ [0.0000, 0.3161, 4.1945, 0.6839],
186
+ [5.0000, 5.0000, 5.0000, 5.0000]])
187
+ """
188
+ means = deltas.new_tensor(means).view(1,
189
+ -1).repeat(1,
190
+ deltas.size(-1) // 4)
191
+ stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
192
+ denorm_deltas = deltas * stds + means
193
+ dx = denorm_deltas[..., 0::4]
194
+ dy = denorm_deltas[..., 1::4]
195
+ dw = denorm_deltas[..., 2::4]
196
+ dh = denorm_deltas[..., 3::4]
197
+ max_ratio = np.abs(np.log(wh_ratio_clip))
198
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
199
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
200
+ x1, y1 = rois[..., 0], rois[..., 1]
201
+ x2, y2 = rois[..., 2], rois[..., 3]
202
+ # Compute center of each roi
203
+ px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
204
+ py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
205
+ # Compute width/height of each roi
206
+ pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
207
+ ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
208
+ # Use exp(network energy) to enlarge/shrink each roi
209
+ gw = pw * dw.exp()
210
+ gh = ph * dh.exp()
211
+ # Use network energy to shift the center of each roi
212
+ gx = px + pw * dx
213
+ gy = py + ph * dy
214
+ # Convert center-xy/width/height to top-left, bottom-right
215
+ x1 = gx - gw * 0.5
216
+ y1 = gy - gh * 0.5
217
+ x2 = gx + gw * 0.5
218
+ y2 = gy + gh * 0.5
219
+
220
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
221
+
222
+ if clip_border and max_shape is not None:
223
+ if not isinstance(max_shape, torch.Tensor):
224
+ max_shape = x1.new_tensor(max_shape)
225
+ max_shape = max_shape[..., :2].type_as(x1)
226
+ if max_shape.ndim == 2:
227
+ assert bboxes.ndim == 3
228
+ assert max_shape.size(0) == bboxes.size(0)
229
+
230
+ min_xy = x1.new_tensor(0)
231
+ max_xy = torch.cat(
232
+ [max_shape] * (deltas.size(-1) // 2),
233
+ dim=-1).flip(-1).unsqueeze(-2)
234
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
235
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
236
+
237
+ return bboxes
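A quick round-trip sketch (not part of the commit, with made-up boxes) showing that bbox2delta and delta2bbox are inverses, as the docstrings state.

import torch
from mmdet.core.bbox.coder.delta_xywh_bbox_coder import bbox2delta, delta2bbox

proposal = torch.Tensor([[0., 0., 10., 10.]])
gt = torch.Tensor([[2., 2., 12., 8.]])

deltas = bbox2delta(proposal, gt)
print(deltas)                        # tensor([[ 0.2000,  0.0000,  0.0000, -0.5108]])
print(delta2bbox(proposal, deltas))  # recovers gt (up to float error): [[ 2.,  2., 12.,  8.]]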
mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py ADDED
@@ -0,0 +1,215 @@
1
+ import mmcv
2
+ import numpy as np
3
+ import torch
4
+
5
+ from ..builder import BBOX_CODERS
6
+ from .base_bbox_coder import BaseBBoxCoder
7
+
8
+
9
+ @BBOX_CODERS.register_module()
10
+ class LegacyDeltaXYWHBBoxCoder(BaseBBoxCoder):
11
+ """Legacy Delta XYWH BBox coder used in MMDet V1.x.
12
+
13
+ Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2,
14
+ y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh)
15
+ back to original bbox (x1, y1, x2, y2).
16
+
17
+ Note:
18
+ The main difference between :class:`LegacyDeltaXYWHBBoxCoder` and
19
+ :class:`DeltaXYWHBBoxCoder` is whether ``+ 1`` is used during width and
20
+ height calculation. We suggest to only use this coder when testing with
21
+ MMDet V1.x models.
22
+
23
+ References:
24
+ .. [1] https://arxiv.org/abs/1311.2524
25
+
26
+ Args:
27
+ target_means (Sequence[float]): denormalizing means of target for
28
+ delta coordinates
29
+ target_stds (Sequence[float]): denormalizing standard deviation of
30
+ target for delta coordinates
31
+ """
32
+
33
+ def __init__(self,
34
+ target_means=(0., 0., 0., 0.),
35
+ target_stds=(1., 1., 1., 1.)):
36
+ super(BaseBBoxCoder, self).__init__()
37
+ self.means = target_means
38
+ self.stds = target_stds
39
+
40
+ def encode(self, bboxes, gt_bboxes):
41
+ """Get box regression transformation deltas that can be used to
42
+ transform the ``bboxes`` into the ``gt_bboxes``.
43
+
44
+ Args:
45
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
46
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
47
+ ground-truth boxes.
48
+
49
+ Returns:
50
+ torch.Tensor: Box transformation deltas
51
+ """
52
+ assert bboxes.size(0) == gt_bboxes.size(0)
53
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
54
+ encoded_bboxes = legacy_bbox2delta(bboxes, gt_bboxes, self.means,
55
+ self.stds)
56
+ return encoded_bboxes
57
+
58
+ def decode(self,
59
+ bboxes,
60
+ pred_bboxes,
61
+ max_shape=None,
62
+ wh_ratio_clip=16 / 1000):
63
+ """Apply transformation `pred_bboxes` to `boxes`.
64
+
65
+ Args:
66
+ boxes (torch.Tensor): Basic boxes.
67
+ pred_bboxes (torch.Tensor): Encoded offsets with shape (N, 4 * num_classes).
68
+ max_shape (tuple[int], optional): Maximum shape of boxes.
69
+ Defaults to None.
70
+ wh_ratio_clip (float, optional): The allowed ratio between
71
+ width and height.
72
+
73
+ Returns:
74
+ torch.Tensor: Decoded boxes.
75
+ """
76
+ assert pred_bboxes.size(0) == bboxes.size(0)
77
+ decoded_bboxes = legacy_delta2bbox(bboxes, pred_bboxes, self.means,
78
+ self.stds, max_shape, wh_ratio_clip)
79
+
80
+ return decoded_bboxes
81
+
82
+
83
+ @mmcv.jit(coderize=True)
84
+ def legacy_bbox2delta(proposals,
85
+ gt,
86
+ means=(0., 0., 0., 0.),
87
+ stds=(1., 1., 1., 1.)):
88
+ """Compute deltas of proposals w.r.t. gt in the MMDet V1.x manner.
89
+
90
+ We usually compute the deltas of x, y, w, h of proposals w.r.t ground
91
+ truth bboxes to get regression target.
92
+ This is the inverse function of `delta2bbox()`
93
+
94
+ Args:
95
+ proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
96
+ gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
97
+ means (Sequence[float]): Denormalizing means for delta coordinates
98
+ stds (Sequence[float]): Denormalizing standard deviation for delta
99
+ coordinates
100
+
101
+ Returns:
102
+ Tensor: deltas with shape (N, 4), where columns represent dx, dy,
103
+ dw, dh.
104
+ """
105
+ assert proposals.size() == gt.size()
106
+
107
+ proposals = proposals.float()
108
+ gt = gt.float()
109
+ px = (proposals[..., 0] + proposals[..., 2]) * 0.5
110
+ py = (proposals[..., 1] + proposals[..., 3]) * 0.5
111
+ pw = proposals[..., 2] - proposals[..., 0] + 1.0
112
+ ph = proposals[..., 3] - proposals[..., 1] + 1.0
113
+
114
+ gx = (gt[..., 0] + gt[..., 2]) * 0.5
115
+ gy = (gt[..., 1] + gt[..., 3]) * 0.5
116
+ gw = gt[..., 2] - gt[..., 0] + 1.0
117
+ gh = gt[..., 3] - gt[..., 1] + 1.0
118
+
119
+ dx = (gx - px) / pw
120
+ dy = (gy - py) / ph
121
+ dw = torch.log(gw / pw)
122
+ dh = torch.log(gh / ph)
123
+ deltas = torch.stack([dx, dy, dw, dh], dim=-1)
124
+
125
+ means = deltas.new_tensor(means).unsqueeze(0)
126
+ stds = deltas.new_tensor(stds).unsqueeze(0)
127
+ deltas = deltas.sub_(means).div_(stds)
128
+
129
+ return deltas
130
+
131
+
132
+ @mmcv.jit(coderize=True)
133
+ def legacy_delta2bbox(rois,
134
+ deltas,
135
+ means=(0., 0., 0., 0.),
136
+ stds=(1., 1., 1., 1.),
137
+ max_shape=None,
138
+ wh_ratio_clip=16 / 1000):
139
+ """Apply deltas to shift/scale base boxes in the MMDet V1.x manner.
140
+
141
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
142
+ network outputs used to shift/scale those boxes.
143
+ This is the inverse function of `bbox2delta()`
144
+
145
+ Args:
146
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4)
147
+ deltas (Tensor): Encoded offsets with respect to each roi.
148
+ Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when
149
+ rois is a grid of anchors. Offset encoding follows [1]_.
150
+ means (Sequence[float]): Denormalizing means for delta coordinates
151
+ stds (Sequence[float]): Denormalizing standard deviation for delta
152
+ coordinates
153
+ max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
154
+ wh_ratio_clip (float): Maximum aspect ratio for boxes.
155
+
156
+ Returns:
157
+ Tensor: Boxes with shape (N, 4), where columns represent
158
+ tl_x, tl_y, br_x, br_y.
159
+
160
+ References:
161
+ .. [1] https://arxiv.org/abs/1311.2524
162
+
163
+ Example:
164
+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
165
+ >>> [ 0., 0., 1., 1.],
166
+ >>> [ 0., 0., 1., 1.],
167
+ >>> [ 5., 5., 5., 5.]])
168
+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
169
+ >>> [ 1., 1., 1., 1.],
170
+ >>> [ 0., 0., 2., -1.],
171
+ >>> [ 0.7, -1.9, -0.5, 0.3]])
172
+ >>> legacy_delta2bbox(rois, deltas, max_shape=(32, 32))
173
+ tensor([[0.0000, 0.0000, 1.5000, 1.5000],
174
+ [0.0000, 0.0000, 5.2183, 5.2183],
175
+ [0.0000, 0.1321, 7.8891, 0.8679],
176
+ [5.3967, 2.4251, 6.0033, 3.7749]])
177
+ """
178
+ means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
179
+ stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
180
+ denorm_deltas = deltas * stds + means
181
+ dx = denorm_deltas[:, 0::4]
182
+ dy = denorm_deltas[:, 1::4]
183
+ dw = denorm_deltas[:, 2::4]
184
+ dh = denorm_deltas[:, 3::4]
185
+ max_ratio = np.abs(np.log(wh_ratio_clip))
186
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
187
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
188
+ # Compute center of each roi
189
+ px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
190
+ py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
191
+ # Compute width/height of each roi
192
+ pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
193
+ ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
194
+ # Use exp(network energy) to enlarge/shrink each roi
195
+ gw = pw * dw.exp()
196
+ gh = ph * dh.exp()
197
+ # Use network energy to shift the center of each roi
198
+ gx = px + pw * dx
199
+ gy = py + ph * dy
200
+ # Convert center-xy/width/height to top-left, bottom-right
201
+
202
+ # The true legacy box coder should +- 0.5 here.
203
+ # However, current implementation improves the performance when testing
204
+ # the models trained in MMDetection 1.X (~0.5 bbox AP, 0.2 mask AP)
205
+ x1 = gx - gw * 0.5
206
+ y1 = gy - gh * 0.5
207
+ x2 = gx + gw * 0.5
208
+ y2 = gy + gh * 0.5
209
+ if max_shape is not None:
210
+ x1 = x1.clamp(min=0, max=max_shape[1] - 1)
211
+ y1 = y1.clamp(min=0, max=max_shape[0] - 1)
212
+ x2 = x2.clamp(min=0, max=max_shape[1] - 1)
213
+ y2 = y2.clamp(min=0, max=max_shape[0] - 1)
214
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
215
+ return bboxes
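A small comparison sketch (not part of the commit, with made-up boxes) of the `+ 1` convention difference called out in the class docstring: the current coder measures widths as x2 - x1, the legacy one as x2 - x1 + 1, so the encoded dw/dh differ for the same pair of boxes.

import torch
from mmdet.core.bbox.coder.delta_xywh_bbox_coder import bbox2delta
from mmdet.core.bbox.coder.legacy_delta_xywh_bbox_coder import legacy_bbox2delta

proposal = torch.Tensor([[0., 0., 9., 9.]])
gt = torch.Tensor([[0., 0., 19., 19.]])

print(bbox2delta(proposal, gt))         # dw = dh = log(19 / 9) ~= 0.747
print(legacy_bbox2delta(proposal, gt))  # dw = dh = log(20 / 10) ~= 0.693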