working version 1
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +122 -0
- configs/_base_/datasets/parking_instance.py +48 -0
- configs/_base_/datasets/parking_instance_coco.py +49 -0
- configs/_base_/datasets/people_real_coco.py +49 -0
- configs/_base_/datasets/walt_people.py +49 -0
- configs/_base_/datasets/walt_vehicle.py +49 -0
- configs/_base_/default_runtime.py +16 -0
- configs/_base_/models/mask_rcnn_swin_fpn.py +127 -0
- configs/_base_/models/occ_mask_rcnn_swin_fpn.py +127 -0
- configs/_base_/schedules/schedule_1x.py +11 -0
- configs/walt/walt_people.py +80 -0
- configs/walt/walt_vehicle.py +80 -0
- docker/Dockerfile +52 -0
- github_vis/cwalt.gif +0 -0
- github_vis/vis_cars.gif +0 -0
- github_vis/vis_people.gif +0 -0
- mmcv_custom/__init__.py +5 -0
- mmcv_custom/checkpoint.py +500 -0
- mmcv_custom/runner/__init__.py +8 -0
- mmcv_custom/runner/checkpoint.py +85 -0
- mmcv_custom/runner/epoch_based_runner.py +104 -0
- mmdet/__init__.py +28 -0
- mmdet/apis/__init__.py +10 -0
- mmdet/apis/inference.py +217 -0
- mmdet/apis/test.py +189 -0
- mmdet/apis/train.py +185 -0
- mmdet/core/__init__.py +7 -0
- mmdet/core/anchor/__init__.py +11 -0
- mmdet/core/anchor/anchor_generator.py +727 -0
- mmdet/core/anchor/builder.py +7 -0
- mmdet/core/anchor/point_generator.py +37 -0
- mmdet/core/anchor/utils.py +71 -0
- mmdet/core/bbox/__init__.py +27 -0
- mmdet/core/bbox/assigners/__init__.py +16 -0
- mmdet/core/bbox/assigners/approx_max_iou_assigner.py +145 -0
- mmdet/core/bbox/assigners/assign_result.py +204 -0
- mmdet/core/bbox/assigners/atss_assigner.py +178 -0
- mmdet/core/bbox/assigners/base_assigner.py +9 -0
- mmdet/core/bbox/assigners/center_region_assigner.py +335 -0
- mmdet/core/bbox/assigners/grid_assigner.py +155 -0
- mmdet/core/bbox/assigners/hungarian_assigner.py +145 -0
- mmdet/core/bbox/assigners/max_iou_assigner.py +212 -0
- mmdet/core/bbox/assigners/point_assigner.py +133 -0
- mmdet/core/bbox/assigners/region_assigner.py +221 -0
- mmdet/core/bbox/builder.py +20 -0
- mmdet/core/bbox/coder/__init__.py +13 -0
- mmdet/core/bbox/coder/base_bbox_coder.py +17 -0
- mmdet/core/bbox/coder/bucketing_bbox_coder.py +350 -0
- mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +237 -0
- mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py +215 -0
.gitignore
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
# Usually these files are written by a python script from a template
|
30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
31 |
+
*.manifest
|
32 |
+
*.spec
|
33 |
+
|
34 |
+
# Installer logs
|
35 |
+
pip-log.txt
|
36 |
+
pip-delete-this-directory.txt
|
37 |
+
|
38 |
+
# Unit test / coverage reports
|
39 |
+
htmlcov/
|
40 |
+
.tox/
|
41 |
+
.coverage
|
42 |
+
.coverage.*
|
43 |
+
.cache
|
44 |
+
nosetests.xml
|
45 |
+
coverage.xml
|
46 |
+
*.cover
|
47 |
+
.hypothesis/
|
48 |
+
.pytest_cache/
|
49 |
+
|
50 |
+
# Translations
|
51 |
+
*.mo
|
52 |
+
*.pot
|
53 |
+
|
54 |
+
# Django stuff:
|
55 |
+
*.log
|
56 |
+
local_settings.py
|
57 |
+
db.sqlite3
|
58 |
+
|
59 |
+
# Flask stuff:
|
60 |
+
instance/
|
61 |
+
.webassets-cache
|
62 |
+
|
63 |
+
# Scrapy stuff:
|
64 |
+
.scrapy
|
65 |
+
|
66 |
+
# Sphinx documentation
|
67 |
+
docs/_build/
|
68 |
+
|
69 |
+
# PyBuilder
|
70 |
+
target/
|
71 |
+
|
72 |
+
# Jupyter Notebook
|
73 |
+
.ipynb_checkpoints
|
74 |
+
|
75 |
+
# pyenv
|
76 |
+
.python-version
|
77 |
+
|
78 |
+
# celery beat schedule file
|
79 |
+
celerybeat-schedule
|
80 |
+
|
81 |
+
# SageMath parsed files
|
82 |
+
*.sage.py
|
83 |
+
|
84 |
+
# Environments
|
85 |
+
.env
|
86 |
+
.venv
|
87 |
+
env/
|
88 |
+
venv/
|
89 |
+
ENV/
|
90 |
+
env.bak/
|
91 |
+
venv.bak/
|
92 |
+
|
93 |
+
# Spyder project settings
|
94 |
+
.spyderproject
|
95 |
+
.spyproject
|
96 |
+
|
97 |
+
# Rope project settings
|
98 |
+
.ropeproject
|
99 |
+
|
100 |
+
# mkdocs documentation
|
101 |
+
/site
|
102 |
+
|
103 |
+
# mypy
|
104 |
+
.mypy_cache/
|
105 |
+
|
106 |
+
data/
|
107 |
+
data
|
108 |
+
.vscode
|
109 |
+
.idea
|
110 |
+
.DS_Store
|
111 |
+
|
112 |
+
# custom
|
113 |
+
*.pkl
|
114 |
+
*.pkl.json
|
115 |
+
*.log.json
|
116 |
+
work_dirs/
|
117 |
+
|
118 |
+
# Pytorch
|
119 |
+
*.pth
|
120 |
+
*.py~
|
121 |
+
*.sh~
|
122 |
+
|
configs/_base_/datasets/parking_instance.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_type = 'ParkingDataset'
|
2 |
+
data_root = 'data/parking/'
|
3 |
+
img_norm_cfg = dict(
|
4 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
5 |
+
train_pipeline = [
|
6 |
+
dict(type='LoadImageFromFile'),
|
7 |
+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
8 |
+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
9 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
10 |
+
dict(type='Normalize', **img_norm_cfg),
|
11 |
+
dict(type='Pad', size_divisor=32),
|
12 |
+
dict(type='DefaultFormatBundle'),
|
13 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d','gt_bboxes_3d_proj']),
|
14 |
+
]
|
15 |
+
test_pipeline = [
|
16 |
+
dict(type='LoadImageFromFile'),
|
17 |
+
dict(
|
18 |
+
type='MultiScaleFlipAug',
|
19 |
+
img_scale=(1333, 800),
|
20 |
+
flip=False,
|
21 |
+
transforms=[
|
22 |
+
dict(type='Resize', keep_ratio=True),
|
23 |
+
dict(type='RandomFlip'),
|
24 |
+
dict(type='Normalize', **img_norm_cfg),
|
25 |
+
dict(type='Pad', size_divisor=32),
|
26 |
+
dict(type='ImageToTensor', keys=['img']),
|
27 |
+
dict(type='Collect', keys=['img']),
|
28 |
+
])
|
29 |
+
]
|
30 |
+
data = dict(
|
31 |
+
samples_per_gpu=1,
|
32 |
+
workers_per_gpu=1,
|
33 |
+
train=dict(
|
34 |
+
type=dataset_type,
|
35 |
+
ann_file=data_root + 'GT_data/',
|
36 |
+
img_prefix=data_root + 'images/',
|
37 |
+
pipeline=train_pipeline),
|
38 |
+
val=dict(
|
39 |
+
type=dataset_type,
|
40 |
+
ann_file=data_root + 'GT_data/',
|
41 |
+
img_prefix=data_root + 'images/',
|
42 |
+
pipeline=test_pipeline),
|
43 |
+
test=dict(
|
44 |
+
type=dataset_type,
|
45 |
+
ann_file=data_root + 'GT_data/',
|
46 |
+
img_prefix=data_root + 'images/',
|
47 |
+
pipeline=test_pipeline))
|
48 |
+
evaluation = dict(metric=['bbox'])#, 'segm'])
|
configs/_base_/datasets/parking_instance_coco.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_type = 'ParkingCocoDataset'
|
2 |
+
data_root = 'data/parking/'
|
3 |
+
data_root_test = 'data/parking_highres/'
|
4 |
+
img_norm_cfg = dict(
|
5 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
6 |
+
train_pipeline = [
|
7 |
+
dict(type='LoadImageFromFile'),
|
8 |
+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
9 |
+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
10 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
11 |
+
dict(type='Normalize', **img_norm_cfg),
|
12 |
+
dict(type='Pad', size_divisor=32),
|
13 |
+
dict(type='DefaultFormatBundle'),
|
14 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
|
15 |
+
]
|
16 |
+
test_pipeline = [
|
17 |
+
dict(type='LoadImageFromFile'),
|
18 |
+
dict(
|
19 |
+
type='MultiScaleFlipAug',
|
20 |
+
img_scale=(1333, 800),
|
21 |
+
flip=False,
|
22 |
+
transforms=[
|
23 |
+
dict(type='Resize', keep_ratio=True),
|
24 |
+
dict(type='RandomFlip'),
|
25 |
+
dict(type='Normalize', **img_norm_cfg),
|
26 |
+
dict(type='Pad', size_divisor=32),
|
27 |
+
dict(type='ImageToTensor', keys=['img']),
|
28 |
+
dict(type='Collect', keys=['img']),
|
29 |
+
])
|
30 |
+
]
|
31 |
+
data = dict(
|
32 |
+
samples_per_gpu=6,
|
33 |
+
workers_per_gpu=6,
|
34 |
+
train=dict(
|
35 |
+
type=dataset_type,
|
36 |
+
ann_file=data_root + 'GT_data/',
|
37 |
+
img_prefix=data_root + 'images/',
|
38 |
+
pipeline=train_pipeline),
|
39 |
+
val=dict(
|
40 |
+
type=dataset_type,
|
41 |
+
ann_file=data_root_test + 'GT_data/',
|
42 |
+
img_prefix=data_root_test + 'images',
|
43 |
+
pipeline=test_pipeline),
|
44 |
+
test=dict(
|
45 |
+
type=dataset_type,
|
46 |
+
ann_file=data_root_test + 'GT_data/',
|
47 |
+
img_prefix=data_root_test + 'images',
|
48 |
+
pipeline=test_pipeline))
|
49 |
+
evaluation = dict(metric=['bbox', 'segm'])
|
configs/_base_/datasets/people_real_coco.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_type = 'WaltDataset'
|
2 |
+
data_root = 'data/cwalt_train/'
|
3 |
+
data_root_test = 'data/cwalt_test/'
|
4 |
+
img_norm_cfg = dict(
|
5 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
6 |
+
train_pipeline = [
|
7 |
+
dict(type='LoadImageFromFile'),
|
8 |
+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
9 |
+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
10 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
11 |
+
dict(type='Normalize', **img_norm_cfg),
|
12 |
+
dict(type='Pad', size_divisor=32),
|
13 |
+
dict(type='DefaultFormatBundle'),
|
14 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
|
15 |
+
]
|
16 |
+
test_pipeline = [
|
17 |
+
dict(type='LoadImageFromFile'),
|
18 |
+
dict(
|
19 |
+
type='MultiScaleFlipAug',
|
20 |
+
img_scale=(1333, 800),
|
21 |
+
flip=False,
|
22 |
+
transforms=[
|
23 |
+
dict(type='Resize', keep_ratio=True),
|
24 |
+
dict(type='RandomFlip'),
|
25 |
+
dict(type='Normalize', **img_norm_cfg),
|
26 |
+
dict(type='Pad', size_divisor=32),
|
27 |
+
dict(type='ImageToTensor', keys=['img']),
|
28 |
+
dict(type='Collect', keys=['img']),
|
29 |
+
])
|
30 |
+
]
|
31 |
+
data = dict(
|
32 |
+
samples_per_gpu=8,
|
33 |
+
workers_per_gpu=8,
|
34 |
+
train=dict(
|
35 |
+
type=dataset_type,
|
36 |
+
ann_file=data_root + '/',
|
37 |
+
img_prefix=data_root + '/',
|
38 |
+
pipeline=train_pipeline),
|
39 |
+
val=dict(
|
40 |
+
type=dataset_type,
|
41 |
+
ann_file=data_root_test + '/',
|
42 |
+
img_prefix=data_root_test + '/',
|
43 |
+
pipeline=test_pipeline),
|
44 |
+
test=dict(
|
45 |
+
type=dataset_type,
|
46 |
+
ann_file=data_root_test + '/',
|
47 |
+
img_prefix=data_root_test + '/',
|
48 |
+
pipeline=test_pipeline))
|
49 |
+
evaluation = dict(metric=['bbox', 'segm'])
|
configs/_base_/datasets/walt_people.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_type = 'WaltDataset'
|
2 |
+
data_root = 'data/cwalt_train/'
|
3 |
+
data_root_test = 'data/cwalt_test/'
|
4 |
+
img_norm_cfg = dict(
|
5 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
6 |
+
train_pipeline = [
|
7 |
+
dict(type='LoadImageFromFile'),
|
8 |
+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
9 |
+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
10 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
11 |
+
dict(type='Normalize', **img_norm_cfg),
|
12 |
+
dict(type='Pad', size_divisor=32),
|
13 |
+
dict(type='DefaultFormatBundle'),
|
14 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
|
15 |
+
]
|
16 |
+
test_pipeline = [
|
17 |
+
dict(type='LoadImageFromFile'),
|
18 |
+
dict(
|
19 |
+
type='MultiScaleFlipAug',
|
20 |
+
img_scale=(1333, 800),
|
21 |
+
flip=False,
|
22 |
+
transforms=[
|
23 |
+
dict(type='Resize', keep_ratio=True),
|
24 |
+
dict(type='RandomFlip'),
|
25 |
+
dict(type='Normalize', **img_norm_cfg),
|
26 |
+
dict(type='Pad', size_divisor=32),
|
27 |
+
dict(type='ImageToTensor', keys=['img']),
|
28 |
+
dict(type='Collect', keys=['img']),
|
29 |
+
])
|
30 |
+
]
|
31 |
+
data = dict(
|
32 |
+
samples_per_gpu=8,
|
33 |
+
workers_per_gpu=8,
|
34 |
+
train=dict(
|
35 |
+
type=dataset_type,
|
36 |
+
ann_file=data_root + '/',
|
37 |
+
img_prefix=data_root + '/',
|
38 |
+
pipeline=train_pipeline),
|
39 |
+
val=dict(
|
40 |
+
type=dataset_type,
|
41 |
+
ann_file=data_root_test + '/',
|
42 |
+
img_prefix=data_root_test + '/',
|
43 |
+
pipeline=test_pipeline),
|
44 |
+
test=dict(
|
45 |
+
type=dataset_type,
|
46 |
+
ann_file=data_root_test + '/',
|
47 |
+
img_prefix=data_root_test + '/',
|
48 |
+
pipeline=test_pipeline))
|
49 |
+
evaluation = dict(metric=['bbox', 'segm'])
|
configs/_base_/datasets/walt_vehicle.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_type = 'WaltDataset'
|
2 |
+
data_root = 'data/cwalt_train/'
|
3 |
+
data_root_test = 'data/cwalt_test/'
|
4 |
+
img_norm_cfg = dict(
|
5 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
6 |
+
train_pipeline = [
|
7 |
+
dict(type='LoadImageFromFile'),
|
8 |
+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
9 |
+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
10 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
11 |
+
dict(type='Normalize', **img_norm_cfg),
|
12 |
+
dict(type='Pad', size_divisor=32),
|
13 |
+
dict(type='DefaultFormatBundle'),
|
14 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
|
15 |
+
]
|
16 |
+
test_pipeline = [
|
17 |
+
dict(type='LoadImageFromFile'),
|
18 |
+
dict(
|
19 |
+
type='MultiScaleFlipAug',
|
20 |
+
img_scale=(1333, 800),
|
21 |
+
flip=False,
|
22 |
+
transforms=[
|
23 |
+
dict(type='Resize', keep_ratio=True),
|
24 |
+
dict(type='RandomFlip'),
|
25 |
+
dict(type='Normalize', **img_norm_cfg),
|
26 |
+
dict(type='Pad', size_divisor=32),
|
27 |
+
dict(type='ImageToTensor', keys=['img']),
|
28 |
+
dict(type='Collect', keys=['img']),
|
29 |
+
])
|
30 |
+
]
|
31 |
+
data = dict(
|
32 |
+
samples_per_gpu=5,
|
33 |
+
workers_per_gpu=5,
|
34 |
+
train=dict(
|
35 |
+
type=dataset_type,
|
36 |
+
ann_file=data_root + '/',
|
37 |
+
img_prefix=data_root + '/',
|
38 |
+
pipeline=train_pipeline),
|
39 |
+
val=dict(
|
40 |
+
type=dataset_type,
|
41 |
+
ann_file=data_root_test + '/',
|
42 |
+
img_prefix=data_root_test + '/',
|
43 |
+
pipeline=test_pipeline),
|
44 |
+
test=dict(
|
45 |
+
type=dataset_type,
|
46 |
+
ann_file=data_root_test + '/',
|
47 |
+
img_prefix=data_root_test + '/',
|
48 |
+
pipeline=test_pipeline))
|
49 |
+
evaluation = dict(metric=['bbox', 'segm'])
|
configs/_base_/default_runtime.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
checkpoint_config = dict(interval=1)
|
2 |
+
# yapf:disable
|
3 |
+
log_config = dict(
|
4 |
+
interval=50,
|
5 |
+
hooks=[
|
6 |
+
dict(type='TextLoggerHook'),
|
7 |
+
# dict(type='TensorboardLoggerHook')
|
8 |
+
])
|
9 |
+
# yapf:enable
|
10 |
+
custom_hooks = [dict(type='NumClassCheckHook')]
|
11 |
+
|
12 |
+
dist_params = dict(backend='nccl')
|
13 |
+
log_level = 'INFO'
|
14 |
+
load_from = None
|
15 |
+
resume_from = None
|
16 |
+
workflow = [('train', 1)]
|
configs/_base_/models/mask_rcnn_swin_fpn.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# model settings
|
2 |
+
model = dict(
|
3 |
+
type='MaskRCNN',
|
4 |
+
pretrained=None,
|
5 |
+
backbone=dict(
|
6 |
+
type='SwinTransformer',
|
7 |
+
embed_dim=96,
|
8 |
+
depths=[2, 2, 6, 2],
|
9 |
+
num_heads=[3, 6, 12, 24],
|
10 |
+
window_size=7,
|
11 |
+
mlp_ratio=4.,
|
12 |
+
qkv_bias=True,
|
13 |
+
qk_scale=None,
|
14 |
+
drop_rate=0.,
|
15 |
+
attn_drop_rate=0.,
|
16 |
+
drop_path_rate=0.2,
|
17 |
+
ape=False,
|
18 |
+
patch_norm=True,
|
19 |
+
out_indices=(0, 1, 2, 3),
|
20 |
+
use_checkpoint=False),
|
21 |
+
neck=dict(
|
22 |
+
type='FPN',
|
23 |
+
in_channels=[96, 192, 384, 768],
|
24 |
+
out_channels=256,
|
25 |
+
num_outs=5),
|
26 |
+
rpn_head=dict(
|
27 |
+
type='RPNHead',
|
28 |
+
in_channels=256,
|
29 |
+
feat_channels=256,
|
30 |
+
anchor_generator=dict(
|
31 |
+
type='AnchorGenerator',
|
32 |
+
scales=[8],
|
33 |
+
ratios=[0.5, 1.0, 2.0],
|
34 |
+
strides=[4, 8, 16, 32, 64]),
|
35 |
+
bbox_coder=dict(
|
36 |
+
type='DeltaXYWHBBoxCoder',
|
37 |
+
target_means=[.0, .0, .0, .0],
|
38 |
+
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
39 |
+
loss_cls=dict(
|
40 |
+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
41 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
42 |
+
roi_head=dict(
|
43 |
+
type='StandardRoIHead',
|
44 |
+
bbox_roi_extractor=dict(
|
45 |
+
type='SingleRoIExtractor',
|
46 |
+
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
47 |
+
out_channels=256,
|
48 |
+
featmap_strides=[4, 8, 16, 32]),
|
49 |
+
bbox_head=dict(
|
50 |
+
type='Shared2FCBBoxHead',
|
51 |
+
in_channels=256,
|
52 |
+
fc_out_channels=1024,
|
53 |
+
roi_feat_size=7,
|
54 |
+
num_classes=80,
|
55 |
+
bbox_coder=dict(
|
56 |
+
type='DeltaXYWHBBoxCoder',
|
57 |
+
target_means=[0., 0., 0., 0.],
|
58 |
+
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
59 |
+
reg_class_agnostic=False,
|
60 |
+
loss_cls=dict(
|
61 |
+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
62 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
63 |
+
mask_roi_extractor=dict(
|
64 |
+
type='SingleRoIExtractor',
|
65 |
+
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
|
66 |
+
out_channels=256,
|
67 |
+
featmap_strides=[4, 8, 16, 32]),
|
68 |
+
mask_head=dict(
|
69 |
+
type='FCNMaskHead',
|
70 |
+
num_convs=4,
|
71 |
+
in_channels=256,
|
72 |
+
conv_out_channels=256,
|
73 |
+
num_classes=80,
|
74 |
+
loss_mask=dict(
|
75 |
+
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
|
76 |
+
# model training and testing settings
|
77 |
+
train_cfg=dict(
|
78 |
+
rpn=dict(
|
79 |
+
assigner=dict(
|
80 |
+
type='MaxIoUAssigner',
|
81 |
+
pos_iou_thr=0.7,
|
82 |
+
neg_iou_thr=0.3,
|
83 |
+
min_pos_iou=0.3,
|
84 |
+
match_low_quality=True,
|
85 |
+
ignore_iof_thr=-1),
|
86 |
+
sampler=dict(
|
87 |
+
type='RandomSampler',
|
88 |
+
num=256,
|
89 |
+
pos_fraction=0.5,
|
90 |
+
neg_pos_ub=-1,
|
91 |
+
add_gt_as_proposals=False),
|
92 |
+
allowed_border=-1,
|
93 |
+
pos_weight=-1,
|
94 |
+
debug=False),
|
95 |
+
rpn_proposal=dict(
|
96 |
+
nms_pre=2000,
|
97 |
+
max_per_img=1000,
|
98 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
99 |
+
min_bbox_size=0),
|
100 |
+
rcnn=dict(
|
101 |
+
assigner=dict(
|
102 |
+
type='MaxIoUAssigner',
|
103 |
+
pos_iou_thr=0.5,
|
104 |
+
neg_iou_thr=0.5,
|
105 |
+
min_pos_iou=0.5,
|
106 |
+
match_low_quality=True,
|
107 |
+
ignore_iof_thr=-1),
|
108 |
+
sampler=dict(
|
109 |
+
type='RandomSampler',
|
110 |
+
num=512,
|
111 |
+
pos_fraction=0.25,
|
112 |
+
neg_pos_ub=-1,
|
113 |
+
add_gt_as_proposals=True),
|
114 |
+
mask_size=28,
|
115 |
+
pos_weight=-1,
|
116 |
+
debug=False)),
|
117 |
+
test_cfg=dict(
|
118 |
+
rpn=dict(
|
119 |
+
nms_pre=1000,
|
120 |
+
max_per_img=1000,
|
121 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
122 |
+
min_bbox_size=0),
|
123 |
+
rcnn=dict(
|
124 |
+
score_thr=0.05,
|
125 |
+
nms=dict(type='nms', iou_threshold=0.5),
|
126 |
+
max_per_img=100,
|
127 |
+
mask_thr_binary=0.5)))
|
configs/_base_/models/occ_mask_rcnn_swin_fpn.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# model settings
|
2 |
+
model = dict(
|
3 |
+
type='MaskRCNN',
|
4 |
+
pretrained=None,
|
5 |
+
backbone=dict(
|
6 |
+
type='SwinTransformer',
|
7 |
+
embed_dim=96,
|
8 |
+
depths=[2, 2, 6, 2],
|
9 |
+
num_heads=[3, 6, 12, 24],
|
10 |
+
window_size=7,
|
11 |
+
mlp_ratio=4.,
|
12 |
+
qkv_bias=True,
|
13 |
+
qk_scale=None,
|
14 |
+
drop_rate=0.,
|
15 |
+
attn_drop_rate=0.,
|
16 |
+
drop_path_rate=0.2,
|
17 |
+
ape=False,
|
18 |
+
patch_norm=True,
|
19 |
+
out_indices=(0, 1, 2, 3),
|
20 |
+
use_checkpoint=False),
|
21 |
+
neck=dict(
|
22 |
+
type='FPN',
|
23 |
+
in_channels=[96, 192, 384, 768],
|
24 |
+
out_channels=256,
|
25 |
+
num_outs=5),
|
26 |
+
rpn_head=dict(
|
27 |
+
type='RPNHead',
|
28 |
+
in_channels=256,
|
29 |
+
feat_channels=256,
|
30 |
+
anchor_generator=dict(
|
31 |
+
type='AnchorGenerator',
|
32 |
+
scales=[8],
|
33 |
+
ratios=[0.5, 1.0, 2.0],
|
34 |
+
strides=[4, 8, 16, 32, 64]),
|
35 |
+
bbox_coder=dict(
|
36 |
+
type='DeltaXYWHBBoxCoder',
|
37 |
+
target_means=[.0, .0, .0, .0],
|
38 |
+
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
39 |
+
loss_cls=dict(
|
40 |
+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
41 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
42 |
+
roi_head=dict(
|
43 |
+
type='StandardRoIHead',
|
44 |
+
bbox_roi_extractor=dict(
|
45 |
+
type='SingleRoIExtractor',
|
46 |
+
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
47 |
+
out_channels=256,
|
48 |
+
featmap_strides=[4, 8, 16, 32]),
|
49 |
+
bbox_head=dict(
|
50 |
+
type='Shared2FCBBoxHead',
|
51 |
+
in_channels=256,
|
52 |
+
fc_out_channels=1024,
|
53 |
+
roi_feat_size=7,
|
54 |
+
num_classes=80,
|
55 |
+
bbox_coder=dict(
|
56 |
+
type='DeltaXYWHBBoxCoder',
|
57 |
+
target_means=[0., 0., 0., 0.],
|
58 |
+
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
59 |
+
reg_class_agnostic=False,
|
60 |
+
loss_cls=dict(
|
61 |
+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
62 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
63 |
+
mask_roi_extractor=dict(
|
64 |
+
type='SingleRoIExtractor',
|
65 |
+
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
|
66 |
+
out_channels=256,
|
67 |
+
featmap_strides=[4, 8, 16, 32]),
|
68 |
+
mask_head=dict(
|
69 |
+
type='FCNOccMaskHead',
|
70 |
+
num_convs=4,
|
71 |
+
in_channels=256,
|
72 |
+
conv_out_channels=256,
|
73 |
+
num_classes=80,
|
74 |
+
loss_mask=dict(
|
75 |
+
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
|
76 |
+
# model training and testing settings
|
77 |
+
train_cfg=dict(
|
78 |
+
rpn=dict(
|
79 |
+
assigner=dict(
|
80 |
+
type='MaxIoUAssigner',
|
81 |
+
pos_iou_thr=0.7,
|
82 |
+
neg_iou_thr=0.3,
|
83 |
+
min_pos_iou=0.3,
|
84 |
+
match_low_quality=True,
|
85 |
+
ignore_iof_thr=-1),
|
86 |
+
sampler=dict(
|
87 |
+
type='RandomSampler',
|
88 |
+
num=256,
|
89 |
+
pos_fraction=0.5,
|
90 |
+
neg_pos_ub=-1,
|
91 |
+
add_gt_as_proposals=False),
|
92 |
+
allowed_border=-1,
|
93 |
+
pos_weight=-1,
|
94 |
+
debug=False),
|
95 |
+
rpn_proposal=dict(
|
96 |
+
nms_pre=2000,
|
97 |
+
max_per_img=1000,
|
98 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
99 |
+
min_bbox_size=0),
|
100 |
+
rcnn=dict(
|
101 |
+
assigner=dict(
|
102 |
+
type='MaxIoUAssigner',
|
103 |
+
pos_iou_thr=0.5,
|
104 |
+
neg_iou_thr=0.5,
|
105 |
+
min_pos_iou=0.5,
|
106 |
+
match_low_quality=True,
|
107 |
+
ignore_iof_thr=-1),
|
108 |
+
sampler=dict(
|
109 |
+
type='RandomSampler',
|
110 |
+
num=512,
|
111 |
+
pos_fraction=0.25,
|
112 |
+
neg_pos_ub=-1,
|
113 |
+
add_gt_as_proposals=True),
|
114 |
+
mask_size=28,
|
115 |
+
pos_weight=-1,
|
116 |
+
debug=False)),
|
117 |
+
test_cfg=dict(
|
118 |
+
rpn=dict(
|
119 |
+
nms_pre=1000,
|
120 |
+
max_per_img=1000,
|
121 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
122 |
+
min_bbox_size=0),
|
123 |
+
rcnn=dict(
|
124 |
+
score_thr=0.05,
|
125 |
+
nms=dict(type='nms', iou_threshold=0.5),
|
126 |
+
max_per_img=100,
|
127 |
+
mask_thr_binary=0.5)))
|
configs/_base_/schedules/schedule_1x.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# optimizer
|
2 |
+
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
3 |
+
optimizer_config = dict(grad_clip=None)
|
4 |
+
# learning policy
|
5 |
+
lr_config = dict(
|
6 |
+
policy='step',
|
7 |
+
warmup='linear',
|
8 |
+
warmup_iters=500,
|
9 |
+
warmup_ratio=0.001,
|
10 |
+
step=[8, 11])
|
11 |
+
runner = dict(type='EpochBasedRunner', max_epochs=12)
|
configs/walt/walt_people.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = [
|
2 |
+
'../_base_/models/occ_mask_rcnn_swin_fpn.py',
|
3 |
+
'../_base_/datasets/walt_people.py',
|
4 |
+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
|
5 |
+
]
|
6 |
+
|
7 |
+
model = dict(
|
8 |
+
backbone=dict(
|
9 |
+
embed_dim=96,
|
10 |
+
depths=[2, 2, 6, 2],
|
11 |
+
num_heads=[3, 6, 12, 24],
|
12 |
+
window_size=7,
|
13 |
+
ape=False,
|
14 |
+
drop_path_rate=0.1,
|
15 |
+
patch_norm=True,
|
16 |
+
use_checkpoint=False
|
17 |
+
),
|
18 |
+
neck=dict(in_channels=[96, 192, 384, 768]))
|
19 |
+
|
20 |
+
img_norm_cfg = dict(
|
21 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
22 |
+
|
23 |
+
# augmentation strategy originates from DETR / Sparse RCNN
|
24 |
+
train_pipeline = [
|
25 |
+
dict(type='LoadImageFromFile'),
|
26 |
+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
27 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
28 |
+
dict(type='AutoAugment',
|
29 |
+
policies=[
|
30 |
+
[
|
31 |
+
dict(type='Resize',
|
32 |
+
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
|
33 |
+
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
|
34 |
+
(736, 1333), (768, 1333), (800, 1333)],
|
35 |
+
multiscale_mode='value',
|
36 |
+
keep_ratio=True)
|
37 |
+
],
|
38 |
+
[
|
39 |
+
dict(type='Resize',
|
40 |
+
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
|
41 |
+
multiscale_mode='value',
|
42 |
+
keep_ratio=True),
|
43 |
+
dict(type='RandomCrop',
|
44 |
+
crop_type='absolute_range',
|
45 |
+
crop_size=(384, 600),
|
46 |
+
allow_negative_crop=True),
|
47 |
+
dict(type='Resize',
|
48 |
+
img_scale=[(480, 1333), (512, 1333), (544, 1333),
|
49 |
+
(576, 1333), (608, 1333), (640, 1333),
|
50 |
+
(672, 1333), (704, 1333), (736, 1333),
|
51 |
+
(768, 1333), (800, 1333)],
|
52 |
+
multiscale_mode='value',
|
53 |
+
override=True,
|
54 |
+
keep_ratio=True)
|
55 |
+
]
|
56 |
+
]),
|
57 |
+
dict(type='Normalize', **img_norm_cfg),
|
58 |
+
dict(type='Pad', size_divisor=32),
|
59 |
+
dict(type='DefaultFormatBundle'),
|
60 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
|
61 |
+
]
|
62 |
+
data = dict(train=dict(pipeline=train_pipeline))
|
63 |
+
|
64 |
+
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
|
65 |
+
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
|
66 |
+
'relative_position_bias_table': dict(decay_mult=0.),
|
67 |
+
'norm': dict(decay_mult=0.)}))
|
68 |
+
lr_config = dict(step=[8, 11])
|
69 |
+
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
|
70 |
+
|
71 |
+
# do not use mmdet version fp16
|
72 |
+
fp16 = None
|
73 |
+
optimizer_config = dict(
|
74 |
+
type="DistOptimizerHook",
|
75 |
+
update_interval=1,
|
76 |
+
grad_clip=None,
|
77 |
+
coalesce=True,
|
78 |
+
bucket_size_mb=-1,
|
79 |
+
use_fp16=True,
|
80 |
+
)
|
configs/walt/walt_vehicle.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = [
|
2 |
+
'../_base_/models/occ_mask_rcnn_swin_fpn.py',
|
3 |
+
'../_base_/datasets/walt_vehicle.py',
|
4 |
+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
|
5 |
+
]
|
6 |
+
|
7 |
+
model = dict(
|
8 |
+
backbone=dict(
|
9 |
+
embed_dim=96,
|
10 |
+
depths=[2, 2, 6, 2],
|
11 |
+
num_heads=[3, 6, 12, 24],
|
12 |
+
window_size=7,
|
13 |
+
ape=False,
|
14 |
+
drop_path_rate=0.1,
|
15 |
+
patch_norm=True,
|
16 |
+
use_checkpoint=False
|
17 |
+
),
|
18 |
+
neck=dict(in_channels=[96, 192, 384, 768]))
|
19 |
+
|
20 |
+
img_norm_cfg = dict(
|
21 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
22 |
+
|
23 |
+
# augmentation strategy originates from DETR / Sparse RCNN
|
24 |
+
train_pipeline = [
|
25 |
+
dict(type='LoadImageFromFile'),
|
26 |
+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
27 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
28 |
+
dict(type='AutoAugment',
|
29 |
+
policies=[
|
30 |
+
[
|
31 |
+
dict(type='Resize',
|
32 |
+
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
|
33 |
+
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
|
34 |
+
(736, 1333), (768, 1333), (800, 1333)],
|
35 |
+
multiscale_mode='value',
|
36 |
+
keep_ratio=True)
|
37 |
+
],
|
38 |
+
[
|
39 |
+
dict(type='Resize',
|
40 |
+
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
|
41 |
+
multiscale_mode='value',
|
42 |
+
keep_ratio=True),
|
43 |
+
dict(type='RandomCrop',
|
44 |
+
crop_type='absolute_range',
|
45 |
+
crop_size=(384, 600),
|
46 |
+
allow_negative_crop=True),
|
47 |
+
dict(type='Resize',
|
48 |
+
img_scale=[(480, 1333), (512, 1333), (544, 1333),
|
49 |
+
(576, 1333), (608, 1333), (640, 1333),
|
50 |
+
(672, 1333), (704, 1333), (736, 1333),
|
51 |
+
(768, 1333), (800, 1333)],
|
52 |
+
multiscale_mode='value',
|
53 |
+
override=True,
|
54 |
+
keep_ratio=True)
|
55 |
+
]
|
56 |
+
]),
|
57 |
+
dict(type='Normalize', **img_norm_cfg),
|
58 |
+
dict(type='Pad', size_divisor=32),
|
59 |
+
dict(type='DefaultFormatBundle'),
|
60 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
|
61 |
+
]
|
62 |
+
data = dict(train=dict(pipeline=train_pipeline))
|
63 |
+
|
64 |
+
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
|
65 |
+
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
|
66 |
+
'relative_position_bias_table': dict(decay_mult=0.),
|
67 |
+
'norm': dict(decay_mult=0.)}))
|
68 |
+
lr_config = dict(step=[8, 11])
|
69 |
+
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
|
70 |
+
|
71 |
+
# do not use mmdet version fp16
|
72 |
+
fp16 = None
|
73 |
+
optimizer_config = dict(
|
74 |
+
type="DistOptimizerHook",
|
75 |
+
update_interval=1,
|
76 |
+
grad_clip=None,
|
77 |
+
coalesce=True,
|
78 |
+
bucket_size_mb=-1,
|
79 |
+
use_fp16=True,
|
80 |
+
)
|
docker/Dockerfile
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ARG PYTORCH="1.9.0"
|
2 |
+
ARG CUDA="11.1"
|
3 |
+
ARG CUDNN="8"
|
4 |
+
|
5 |
+
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
6 |
+
|
7 |
+
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
|
8 |
+
ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
|
9 |
+
ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
|
10 |
+
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
11 |
+
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
12 |
+
RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
|
13 |
+
&& apt-get clean \
|
14 |
+
&& rm -rf /var/lib/apt/lists/*
|
15 |
+
|
16 |
+
# Install MMCV
|
17 |
+
#RUN pip install mmcv-full==1.3.8 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
|
18 |
+
# -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
|
19 |
+
RUN pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
|
20 |
+
# Install MMDetection
|
21 |
+
RUN conda clean --all
|
22 |
+
RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection
|
23 |
+
WORKDIR /mmdetection
|
24 |
+
ENV FORCE_CUDA="1"
|
25 |
+
RUN cd /mmdetection && git checkout 7bd39044f35aec4b90dd797b965777541a8678ff
|
26 |
+
RUN pip install -r requirements/build.txt
|
27 |
+
RUN pip install --no-cache-dir -e .
|
28 |
+
RUN apt-get update
|
29 |
+
RUN apt-get install -y vim
|
30 |
+
RUN pip uninstall -y pycocotools
|
31 |
+
RUN pip install mmpycocotools timm scikit-image imagesize
|
32 |
+
|
33 |
+
|
34 |
+
# make sure we don't overwrite some existing directory called "apex"
|
35 |
+
WORKDIR /tmp/unique_for_apex
|
36 |
+
# uninstall Apex if present, twice to make absolutely sure :)
|
37 |
+
RUN pip uninstall -y apex || :
|
38 |
+
RUN pip uninstall -y apex || :
|
39 |
+
# SHA is something the user can touch to force recreation of this Docker layer,
|
40 |
+
# and therefore force cloning of the latest version of Apex
|
41 |
+
RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git
|
42 |
+
WORKDIR /tmp/unique_for_apex/apex
|
43 |
+
RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
|
44 |
+
RUN pip install seaborn sklearn imantics gradio
|
45 |
+
WORKDIR /code
|
46 |
+
ENTRYPOINT ["python", "app.py"]
|
47 |
+
|
48 |
+
#RUN git clone https://github.com/NVIDIA/apex
|
49 |
+
#RUN cd apex
|
50 |
+
#RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
|
51 |
+
#RUN pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
52 |
+
|
github_vis/cwalt.gif
ADDED
![]() |
github_vis/vis_cars.gif
ADDED
![]() |
github_vis/vis_people.gif
ADDED
![]() |
mmcv_custom/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from .checkpoint import load_checkpoint
|
4 |
+
|
5 |
+
__all__ = ['load_checkpoint']
|
mmcv_custom/checkpoint.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import pkgutil
|
6 |
+
import time
|
7 |
+
import warnings
|
8 |
+
from collections import OrderedDict
|
9 |
+
from importlib import import_module
|
10 |
+
from tempfile import TemporaryDirectory
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torchvision
|
14 |
+
from torch.optim import Optimizer
|
15 |
+
from torch.utils import model_zoo
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
import mmcv
|
19 |
+
from mmcv.fileio import FileClient
|
20 |
+
from mmcv.fileio import load as load_file
|
21 |
+
from mmcv.parallel import is_module_wrapper
|
22 |
+
from mmcv.utils import mkdir_or_exist
|
23 |
+
from mmcv.runner import get_dist_info
|
24 |
+
|
25 |
+
ENV_MMCV_HOME = 'MMCV_HOME'
|
26 |
+
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
27 |
+
DEFAULT_CACHE_DIR = '~/.cache'
|
28 |
+
|
29 |
+
|
30 |
+
def _get_mmcv_home():
|
31 |
+
mmcv_home = os.path.expanduser(
|
32 |
+
os.getenv(
|
33 |
+
ENV_MMCV_HOME,
|
34 |
+
os.path.join(
|
35 |
+
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
|
36 |
+
|
37 |
+
mkdir_or_exist(mmcv_home)
|
38 |
+
return mmcv_home
|
39 |
+
|
40 |
+
|
41 |
+
def load_state_dict(module, state_dict, strict=False, logger=None):
|
42 |
+
"""Load state_dict to a module.
|
43 |
+
|
44 |
+
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
|
45 |
+
Default value for ``strict`` is set to ``False`` and the message for
|
46 |
+
param mismatch will be shown even if strict is False.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
module (Module): Module that receives the state_dict.
|
50 |
+
state_dict (OrderedDict): Weights.
|
51 |
+
strict (bool): whether to strictly enforce that the keys
|
52 |
+
in :attr:`state_dict` match the keys returned by this module's
|
53 |
+
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
|
54 |
+
logger (:obj:`logging.Logger`, optional): Logger to log the error
|
55 |
+
message. If not specified, print function will be used.
|
56 |
+
"""
|
57 |
+
unexpected_keys = []
|
58 |
+
all_missing_keys = []
|
59 |
+
err_msg = []
|
60 |
+
|
61 |
+
metadata = getattr(state_dict, '_metadata', None)
|
62 |
+
state_dict = state_dict.copy()
|
63 |
+
if metadata is not None:
|
64 |
+
state_dict._metadata = metadata
|
65 |
+
|
66 |
+
# use _load_from_state_dict to enable checkpoint version control
|
67 |
+
def load(module, prefix=''):
|
68 |
+
# recursively check parallel module in case that the model has a
|
69 |
+
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
70 |
+
if is_module_wrapper(module):
|
71 |
+
module = module.module
|
72 |
+
local_metadata = {} if metadata is None else metadata.get(
|
73 |
+
prefix[:-1], {})
|
74 |
+
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
|
75 |
+
all_missing_keys, unexpected_keys,
|
76 |
+
err_msg)
|
77 |
+
for name, child in module._modules.items():
|
78 |
+
if child is not None:
|
79 |
+
load(child, prefix + name + '.')
|
80 |
+
|
81 |
+
load(module)
|
82 |
+
load = None # break load->load reference cycle
|
83 |
+
|
84 |
+
# ignore "num_batches_tracked" of BN layers
|
85 |
+
missing_keys = [
|
86 |
+
key for key in all_missing_keys if 'num_batches_tracked' not in key
|
87 |
+
]
|
88 |
+
|
89 |
+
if unexpected_keys:
|
90 |
+
err_msg.append('unexpected key in source '
|
91 |
+
f'state_dict: {", ".join(unexpected_keys)}\n')
|
92 |
+
if missing_keys:
|
93 |
+
err_msg.append(
|
94 |
+
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
|
95 |
+
|
96 |
+
rank, _ = get_dist_info()
|
97 |
+
if len(err_msg) > 0 and rank == 0:
|
98 |
+
err_msg.insert(
|
99 |
+
0, 'The model and loaded state dict do not match exactly\n')
|
100 |
+
err_msg = '\n'.join(err_msg)
|
101 |
+
if strict:
|
102 |
+
raise RuntimeError(err_msg)
|
103 |
+
elif logger is not None:
|
104 |
+
logger.warning(err_msg)
|
105 |
+
else:
|
106 |
+
print(err_msg)
|
107 |
+
|
108 |
+
|
109 |
+
def load_url_dist(url, model_dir=None):
|
110 |
+
"""In distributed setting, this function only download checkpoint at local
|
111 |
+
rank 0."""
|
112 |
+
rank, world_size = get_dist_info()
|
113 |
+
rank = int(os.environ.get('LOCAL_RANK', rank))
|
114 |
+
if rank == 0:
|
115 |
+
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
116 |
+
if world_size > 1:
|
117 |
+
torch.distributed.barrier()
|
118 |
+
if rank > 0:
|
119 |
+
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
120 |
+
return checkpoint
|
121 |
+
|
122 |
+
|
123 |
+
def load_pavimodel_dist(model_path, map_location=None):
|
124 |
+
"""In distributed setting, this function only download checkpoint at local
|
125 |
+
rank 0."""
|
126 |
+
try:
|
127 |
+
from pavi import modelcloud
|
128 |
+
except ImportError:
|
129 |
+
raise ImportError(
|
130 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
131 |
+
rank, world_size = get_dist_info()
|
132 |
+
rank = int(os.environ.get('LOCAL_RANK', rank))
|
133 |
+
if rank == 0:
|
134 |
+
model = modelcloud.get(model_path)
|
135 |
+
with TemporaryDirectory() as tmp_dir:
|
136 |
+
downloaded_file = osp.join(tmp_dir, model.name)
|
137 |
+
model.download(downloaded_file)
|
138 |
+
checkpoint = torch.load(downloaded_file, map_location=map_location)
|
139 |
+
if world_size > 1:
|
140 |
+
torch.distributed.barrier()
|
141 |
+
if rank > 0:
|
142 |
+
model = modelcloud.get(model_path)
|
143 |
+
with TemporaryDirectory() as tmp_dir:
|
144 |
+
downloaded_file = osp.join(tmp_dir, model.name)
|
145 |
+
model.download(downloaded_file)
|
146 |
+
checkpoint = torch.load(
|
147 |
+
downloaded_file, map_location=map_location)
|
148 |
+
return checkpoint
|
149 |
+
|
150 |
+
|
151 |
+
def load_fileclient_dist(filename, backend, map_location):
|
152 |
+
"""In distributed setting, this function only download checkpoint at local
|
153 |
+
rank 0."""
|
154 |
+
rank, world_size = get_dist_info()
|
155 |
+
rank = int(os.environ.get('LOCAL_RANK', rank))
|
156 |
+
allowed_backends = ['ceph']
|
157 |
+
if backend not in allowed_backends:
|
158 |
+
raise ValueError(f'Load from Backend {backend} is not supported.')
|
159 |
+
if rank == 0:
|
160 |
+
fileclient = FileClient(backend=backend)
|
161 |
+
buffer = io.BytesIO(fileclient.get(filename))
|
162 |
+
checkpoint = torch.load(buffer, map_location=map_location)
|
163 |
+
if world_size > 1:
|
164 |
+
torch.distributed.barrier()
|
165 |
+
if rank > 0:
|
166 |
+
fileclient = FileClient(backend=backend)
|
167 |
+
buffer = io.BytesIO(fileclient.get(filename))
|
168 |
+
checkpoint = torch.load(buffer, map_location=map_location)
|
169 |
+
return checkpoint
|
170 |
+
|
171 |
+
|
172 |
+
def get_torchvision_models():
|
173 |
+
model_urls = dict()
|
174 |
+
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
|
175 |
+
if ispkg:
|
176 |
+
continue
|
177 |
+
_zoo = import_module(f'torchvision.models.{name}')
|
178 |
+
if hasattr(_zoo, 'model_urls'):
|
179 |
+
_urls = getattr(_zoo, 'model_urls')
|
180 |
+
model_urls.update(_urls)
|
181 |
+
return model_urls
|
182 |
+
|
183 |
+
|
184 |
+
def get_external_models():
|
185 |
+
mmcv_home = _get_mmcv_home()
|
186 |
+
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
|
187 |
+
default_urls = load_file(default_json_path)
|
188 |
+
assert isinstance(default_urls, dict)
|
189 |
+
external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
|
190 |
+
if osp.exists(external_json_path):
|
191 |
+
external_urls = load_file(external_json_path)
|
192 |
+
assert isinstance(external_urls, dict)
|
193 |
+
default_urls.update(external_urls)
|
194 |
+
|
195 |
+
return default_urls
|
196 |
+
|
197 |
+
|
198 |
+
def get_mmcls_models():
|
199 |
+
mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
|
200 |
+
mmcls_urls = load_file(mmcls_json_path)
|
201 |
+
|
202 |
+
return mmcls_urls
|
203 |
+
|
204 |
+
|
205 |
+
def get_deprecated_model_names():
|
206 |
+
deprecate_json_path = osp.join(mmcv.__path__[0],
|
207 |
+
'model_zoo/deprecated.json')
|
208 |
+
deprecate_urls = load_file(deprecate_json_path)
|
209 |
+
assert isinstance(deprecate_urls, dict)
|
210 |
+
|
211 |
+
return deprecate_urls
|
212 |
+
|
213 |
+
|
214 |
+
def _process_mmcls_checkpoint(checkpoint):
|
215 |
+
state_dict = checkpoint['state_dict']
|
216 |
+
new_state_dict = OrderedDict()
|
217 |
+
for k, v in state_dict.items():
|
218 |
+
if k.startswith('backbone.'):
|
219 |
+
new_state_dict[k[9:]] = v
|
220 |
+
new_checkpoint = dict(state_dict=new_state_dict)
|
221 |
+
|
222 |
+
return new_checkpoint
|
223 |
+
|
224 |
+
|
225 |
+
def _load_checkpoint(filename, map_location=None):
|
226 |
+
"""Load checkpoint from somewhere (modelzoo, file, url).
|
227 |
+
|
228 |
+
Args:
|
229 |
+
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
230 |
+
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
231 |
+
details.
|
232 |
+
map_location (str | None): Same as :func:`torch.load`. Default: None.
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
dict | OrderedDict: The loaded checkpoint. It can be either an
|
236 |
+
OrderedDict storing model weights or a dict containing other
|
237 |
+
information, which depends on the checkpoint.
|
238 |
+
"""
|
239 |
+
if filename.startswith('modelzoo://'):
|
240 |
+
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
|
241 |
+
'use "torchvision://" instead')
|
242 |
+
model_urls = get_torchvision_models()
|
243 |
+
model_name = filename[11:]
|
244 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
245 |
+
elif filename.startswith('torchvision://'):
|
246 |
+
model_urls = get_torchvision_models()
|
247 |
+
model_name = filename[14:]
|
248 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
249 |
+
elif filename.startswith('open-mmlab://'):
|
250 |
+
model_urls = get_external_models()
|
251 |
+
model_name = filename[13:]
|
252 |
+
deprecated_urls = get_deprecated_model_names()
|
253 |
+
if model_name in deprecated_urls:
|
254 |
+
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
|
255 |
+
f'of open-mmlab://{deprecated_urls[model_name]}')
|
256 |
+
model_name = deprecated_urls[model_name]
|
257 |
+
model_url = model_urls[model_name]
|
258 |
+
# check if is url
|
259 |
+
if model_url.startswith(('http://', 'https://')):
|
260 |
+
checkpoint = load_url_dist(model_url)
|
261 |
+
else:
|
262 |
+
filename = osp.join(_get_mmcv_home(), model_url)
|
263 |
+
if not osp.isfile(filename):
|
264 |
+
raise IOError(f'{filename} is not a checkpoint file')
|
265 |
+
checkpoint = torch.load(filename, map_location=map_location)
|
266 |
+
elif filename.startswith('mmcls://'):
|
267 |
+
model_urls = get_mmcls_models()
|
268 |
+
model_name = filename[8:]
|
269 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
270 |
+
checkpoint = _process_mmcls_checkpoint(checkpoint)
|
271 |
+
elif filename.startswith(('http://', 'https://')):
|
272 |
+
checkpoint = load_url_dist(filename)
|
273 |
+
elif filename.startswith('pavi://'):
|
274 |
+
model_path = filename[7:]
|
275 |
+
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
|
276 |
+
elif filename.startswith('s3://'):
|
277 |
+
checkpoint = load_fileclient_dist(
|
278 |
+
filename, backend='ceph', map_location=map_location)
|
279 |
+
else:
|
280 |
+
if not osp.isfile(filename):
|
281 |
+
raise IOError(f'{filename} is not a checkpoint file')
|
282 |
+
checkpoint = torch.load(filename, map_location=map_location)
|
283 |
+
return checkpoint
|
284 |
+
|
285 |
+
|
286 |
+
def load_checkpoint(model,
|
287 |
+
filename,
|
288 |
+
map_location='cpu',
|
289 |
+
strict=False,
|
290 |
+
logger=None):
|
291 |
+
"""Load checkpoint from a file or URI.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
model (Module): Module to load checkpoint.
|
295 |
+
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
296 |
+
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
297 |
+
details.
|
298 |
+
map_location (str): Same as :func:`torch.load`.
|
299 |
+
strict (bool): Whether to allow different params for the model and
|
300 |
+
checkpoint.
|
301 |
+
logger (:mod:`logging.Logger` or None): The logger for error message.
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
dict or OrderedDict: The loaded checkpoint.
|
305 |
+
"""
|
306 |
+
checkpoint = _load_checkpoint(filename, map_location)
|
307 |
+
# OrderedDict is a subclass of dict
|
308 |
+
if not isinstance(checkpoint, dict):
|
309 |
+
raise RuntimeError(
|
310 |
+
f'No state_dict found in checkpoint file {filename}')
|
311 |
+
# get state_dict from checkpoint
|
312 |
+
if 'state_dict' in checkpoint:
|
313 |
+
state_dict = checkpoint['state_dict']
|
314 |
+
elif 'model' in checkpoint:
|
315 |
+
state_dict = checkpoint['model']
|
316 |
+
else:
|
317 |
+
state_dict = checkpoint
|
318 |
+
# strip prefix of state_dict
|
319 |
+
if list(state_dict.keys())[0].startswith('module.'):
|
320 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
321 |
+
|
322 |
+
# for MoBY, load model of online branch
|
323 |
+
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
|
324 |
+
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
|
325 |
+
|
326 |
+
# reshape absolute position embedding
|
327 |
+
if state_dict.get('absolute_pos_embed') is not None:
|
328 |
+
absolute_pos_embed = state_dict['absolute_pos_embed']
|
329 |
+
N1, L, C1 = absolute_pos_embed.size()
|
330 |
+
N2, C2, H, W = model.absolute_pos_embed.size()
|
331 |
+
if N1 != N2 or C1 != C2 or L != H*W:
|
332 |
+
logger.warning("Error in loading absolute_pos_embed, pass")
|
333 |
+
else:
|
334 |
+
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
|
335 |
+
|
336 |
+
# interpolate position bias table if needed
|
337 |
+
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
|
338 |
+
for table_key in relative_position_bias_table_keys:
|
339 |
+
table_pretrained = state_dict[table_key]
|
340 |
+
table_current = model.state_dict()[table_key]
|
341 |
+
L1, nH1 = table_pretrained.size()
|
342 |
+
L2, nH2 = table_current.size()
|
343 |
+
if nH1 != nH2:
|
344 |
+
logger.warning(f"Error in loading {table_key}, pass")
|
345 |
+
else:
|
346 |
+
if L1 != L2:
|
347 |
+
S1 = int(L1 ** 0.5)
|
348 |
+
S2 = int(L2 ** 0.5)
|
349 |
+
table_pretrained_resized = F.interpolate(
|
350 |
+
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
|
351 |
+
size=(S2, S2), mode='bicubic')
|
352 |
+
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
|
353 |
+
|
354 |
+
# load state_dict
|
355 |
+
load_state_dict(model, state_dict, strict, logger)
|
356 |
+
return checkpoint
|
357 |
+
|
358 |
+
|
359 |
+
def weights_to_cpu(state_dict):
|
360 |
+
"""Copy a model state_dict to cpu.
|
361 |
+
|
362 |
+
Args:
|
363 |
+
state_dict (OrderedDict): Model weights on GPU.
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
OrderedDict: Model weights on GPU.
|
367 |
+
"""
|
368 |
+
state_dict_cpu = OrderedDict()
|
369 |
+
for key, val in state_dict.items():
|
370 |
+
state_dict_cpu[key] = val.cpu()
|
371 |
+
return state_dict_cpu
|
372 |
+
|
373 |
+
|
374 |
+
def _save_to_state_dict(module, destination, prefix, keep_vars):
|
375 |
+
"""Saves module state to `destination` dictionary.
|
376 |
+
|
377 |
+
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
|
378 |
+
|
379 |
+
Args:
|
380 |
+
module (nn.Module): The module to generate state_dict.
|
381 |
+
destination (dict): A dict where state will be stored.
|
382 |
+
prefix (str): The prefix for parameters and buffers used in this
|
383 |
+
module.
|
384 |
+
"""
|
385 |
+
for name, param in module._parameters.items():
|
386 |
+
if param is not None:
|
387 |
+
destination[prefix + name] = param if keep_vars else param.detach()
|
388 |
+
for name, buf in module._buffers.items():
|
389 |
+
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
|
390 |
+
if buf is not None:
|
391 |
+
destination[prefix + name] = buf if keep_vars else buf.detach()
|
392 |
+
|
393 |
+
|
394 |
+
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
|
395 |
+
"""Returns a dictionary containing a whole state of the module.
|
396 |
+
|
397 |
+
Both parameters and persistent buffers (e.g. running averages) are
|
398 |
+
included. Keys are corresponding parameter and buffer names.
|
399 |
+
|
400 |
+
This method is modified from :meth:`torch.nn.Module.state_dict` to
|
401 |
+
recursively check parallel module in case that the model has a complicated
|
402 |
+
structure, e.g., nn.Module(nn.Module(DDP)).
|
403 |
+
|
404 |
+
Args:
|
405 |
+
module (nn.Module): The module to generate state_dict.
|
406 |
+
destination (OrderedDict): Returned dict for the state of the
|
407 |
+
module.
|
408 |
+
prefix (str): Prefix of the key.
|
409 |
+
keep_vars (bool): Whether to keep the variable property of the
|
410 |
+
parameters. Default: False.
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
dict: A dictionary containing a whole state of the module.
|
414 |
+
"""
|
415 |
+
# recursively check parallel module in case that the model has a
|
416 |
+
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
417 |
+
if is_module_wrapper(module):
|
418 |
+
module = module.module
|
419 |
+
|
420 |
+
# below is the same as torch.nn.Module.state_dict()
|
421 |
+
if destination is None:
|
422 |
+
destination = OrderedDict()
|
423 |
+
destination._metadata = OrderedDict()
|
424 |
+
destination._metadata[prefix[:-1]] = local_metadata = dict(
|
425 |
+
version=module._version)
|
426 |
+
_save_to_state_dict(module, destination, prefix, keep_vars)
|
427 |
+
for name, child in module._modules.items():
|
428 |
+
if child is not None:
|
429 |
+
get_state_dict(
|
430 |
+
child, destination, prefix + name + '.', keep_vars=keep_vars)
|
431 |
+
for hook in module._state_dict_hooks.values():
|
432 |
+
hook_result = hook(module, destination, prefix, local_metadata)
|
433 |
+
if hook_result is not None:
|
434 |
+
destination = hook_result
|
435 |
+
return destination
|
436 |
+
|
437 |
+
|
438 |
+
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
439 |
+
"""Save checkpoint to file.
|
440 |
+
|
441 |
+
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
|
442 |
+
``optimizer``. By default ``meta`` will contain version and time info.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
model (Module): Module whose params are to be saved.
|
446 |
+
filename (str): Checkpoint filename.
|
447 |
+
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
448 |
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
449 |
+
"""
|
450 |
+
if meta is None:
|
451 |
+
meta = {}
|
452 |
+
elif not isinstance(meta, dict):
|
453 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
454 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
455 |
+
|
456 |
+
if is_module_wrapper(model):
|
457 |
+
model = model.module
|
458 |
+
|
459 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
460 |
+
# save class name to the meta
|
461 |
+
meta.update(CLASSES=model.CLASSES)
|
462 |
+
|
463 |
+
checkpoint = {
|
464 |
+
'meta': meta,
|
465 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
466 |
+
}
|
467 |
+
# save optimizer state dict in the checkpoint
|
468 |
+
if isinstance(optimizer, Optimizer):
|
469 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
470 |
+
elif isinstance(optimizer, dict):
|
471 |
+
checkpoint['optimizer'] = {}
|
472 |
+
for name, optim in optimizer.items():
|
473 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
474 |
+
|
475 |
+
if filename.startswith('pavi://'):
|
476 |
+
try:
|
477 |
+
from pavi import modelcloud
|
478 |
+
from pavi.exception import NodeNotFoundError
|
479 |
+
except ImportError:
|
480 |
+
raise ImportError(
|
481 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
482 |
+
model_path = filename[7:]
|
483 |
+
root = modelcloud.Folder()
|
484 |
+
model_dir, model_name = osp.split(model_path)
|
485 |
+
try:
|
486 |
+
model = modelcloud.get(model_dir)
|
487 |
+
except NodeNotFoundError:
|
488 |
+
model = root.create_training_model(model_dir)
|
489 |
+
with TemporaryDirectory() as tmp_dir:
|
490 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
491 |
+
with open(checkpoint_file, 'wb') as f:
|
492 |
+
torch.save(checkpoint, f)
|
493 |
+
f.flush()
|
494 |
+
model.create_file(checkpoint_file, name=model_name)
|
495 |
+
else:
|
496 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
497 |
+
# immediately flush buffer
|
498 |
+
with open(filename, 'wb') as f:
|
499 |
+
torch.save(checkpoint, f)
|
500 |
+
f.flush()
|
mmcv_custom/runner/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
2 |
+
from .checkpoint import save_checkpoint
|
3 |
+
from .epoch_based_runner import EpochBasedRunnerAmp
|
4 |
+
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'EpochBasedRunnerAmp', 'save_checkpoint'
|
8 |
+
]
|
mmcv_custom/runner/checkpoint.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
import time
|
4 |
+
from tempfile import TemporaryDirectory
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.optim import Optimizer
|
8 |
+
|
9 |
+
import mmcv
|
10 |
+
from mmcv.parallel import is_module_wrapper
|
11 |
+
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
|
12 |
+
|
13 |
+
try:
|
14 |
+
import apex
|
15 |
+
except:
|
16 |
+
print('apex is not installed')
|
17 |
+
|
18 |
+
|
19 |
+
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
20 |
+
"""Save checkpoint to file.
|
21 |
+
|
22 |
+
The checkpoint will have 4 fields: ``meta``, ``state_dict`` and
|
23 |
+
``optimizer``, ``amp``. By default ``meta`` will contain version
|
24 |
+
and time info.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
model (Module): Module whose params are to be saved.
|
28 |
+
filename (str): Checkpoint filename.
|
29 |
+
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
30 |
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
31 |
+
"""
|
32 |
+
if meta is None:
|
33 |
+
meta = {}
|
34 |
+
elif not isinstance(meta, dict):
|
35 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
36 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
37 |
+
|
38 |
+
if is_module_wrapper(model):
|
39 |
+
model = model.module
|
40 |
+
|
41 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
42 |
+
# save class name to the meta
|
43 |
+
meta.update(CLASSES=model.CLASSES)
|
44 |
+
|
45 |
+
checkpoint = {
|
46 |
+
'meta': meta,
|
47 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
48 |
+
}
|
49 |
+
# save optimizer state dict in the checkpoint
|
50 |
+
if isinstance(optimizer, Optimizer):
|
51 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
52 |
+
elif isinstance(optimizer, dict):
|
53 |
+
checkpoint['optimizer'] = {}
|
54 |
+
for name, optim in optimizer.items():
|
55 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
56 |
+
|
57 |
+
# save amp state dict in the checkpoint
|
58 |
+
checkpoint['amp'] = apex.amp.state_dict()
|
59 |
+
|
60 |
+
if filename.startswith('pavi://'):
|
61 |
+
try:
|
62 |
+
from pavi import modelcloud
|
63 |
+
from pavi.exception import NodeNotFoundError
|
64 |
+
except ImportError:
|
65 |
+
raise ImportError(
|
66 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
67 |
+
model_path = filename[7:]
|
68 |
+
root = modelcloud.Folder()
|
69 |
+
model_dir, model_name = osp.split(model_path)
|
70 |
+
try:
|
71 |
+
model = modelcloud.get(model_dir)
|
72 |
+
except NodeNotFoundError:
|
73 |
+
model = root.create_training_model(model_dir)
|
74 |
+
with TemporaryDirectory() as tmp_dir:
|
75 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
76 |
+
with open(checkpoint_file, 'wb') as f:
|
77 |
+
torch.save(checkpoint, f)
|
78 |
+
f.flush()
|
79 |
+
model.create_file(checkpoint_file, name=model_name)
|
80 |
+
else:
|
81 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
82 |
+
# immediately flush buffer
|
83 |
+
with open(filename, 'wb') as f:
|
84 |
+
torch.save(checkpoint, f)
|
85 |
+
f.flush()
|
mmcv_custom/runner/epoch_based_runner.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
import platform
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.optim import Optimizer
|
8 |
+
|
9 |
+
import mmcv
|
10 |
+
from mmcv.runner import RUNNERS, EpochBasedRunner
|
11 |
+
from .checkpoint import save_checkpoint
|
12 |
+
|
13 |
+
try:
|
14 |
+
import apex
|
15 |
+
except:
|
16 |
+
print('apex is not installed')
|
17 |
+
|
18 |
+
|
19 |
+
@RUNNERS.register_module()
|
20 |
+
class EpochBasedRunnerAmp(EpochBasedRunner):
|
21 |
+
"""Epoch-based Runner with AMP support.
|
22 |
+
|
23 |
+
This runner train models epoch by epoch.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def save_checkpoint(self,
|
27 |
+
out_dir,
|
28 |
+
filename_tmpl='epoch_{}.pth',
|
29 |
+
save_optimizer=True,
|
30 |
+
meta=None,
|
31 |
+
create_symlink=True):
|
32 |
+
"""Save the checkpoint.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
out_dir (str): The directory that checkpoints are saved.
|
36 |
+
filename_tmpl (str, optional): The checkpoint filename template,
|
37 |
+
which contains a placeholder for the epoch number.
|
38 |
+
Defaults to 'epoch_{}.pth'.
|
39 |
+
save_optimizer (bool, optional): Whether to save the optimizer to
|
40 |
+
the checkpoint. Defaults to True.
|
41 |
+
meta (dict, optional): The meta information to be saved in the
|
42 |
+
checkpoint. Defaults to None.
|
43 |
+
create_symlink (bool, optional): Whether to create a symlink
|
44 |
+
"latest.pth" to point to the latest checkpoint.
|
45 |
+
Defaults to True.
|
46 |
+
"""
|
47 |
+
if meta is None:
|
48 |
+
meta = dict(epoch=self.epoch + 1, iter=self.iter)
|
49 |
+
elif isinstance(meta, dict):
|
50 |
+
meta.update(epoch=self.epoch + 1, iter=self.iter)
|
51 |
+
else:
|
52 |
+
raise TypeError(
|
53 |
+
f'meta should be a dict or None, but got {type(meta)}')
|
54 |
+
if self.meta is not None:
|
55 |
+
meta.update(self.meta)
|
56 |
+
|
57 |
+
filename = filename_tmpl.format(self.epoch + 1)
|
58 |
+
filepath = osp.join(out_dir, filename)
|
59 |
+
optimizer = self.optimizer if save_optimizer else None
|
60 |
+
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
|
61 |
+
# in some environments, `os.symlink` is not supported, you may need to
|
62 |
+
# set `create_symlink` to False
|
63 |
+
if create_symlink:
|
64 |
+
dst_file = osp.join(out_dir, 'latest.pth')
|
65 |
+
if platform.system() != 'Windows':
|
66 |
+
mmcv.symlink(filename, dst_file)
|
67 |
+
else:
|
68 |
+
shutil.copy(filepath, dst_file)
|
69 |
+
|
70 |
+
def resume(self,
|
71 |
+
checkpoint,
|
72 |
+
resume_optimizer=True,
|
73 |
+
map_location='default'):
|
74 |
+
if map_location == 'default':
|
75 |
+
if torch.cuda.is_available():
|
76 |
+
device_id = torch.cuda.current_device()
|
77 |
+
checkpoint = self.load_checkpoint(
|
78 |
+
checkpoint,
|
79 |
+
map_location=lambda storage, loc: storage.cuda(device_id))
|
80 |
+
else:
|
81 |
+
checkpoint = self.load_checkpoint(checkpoint)
|
82 |
+
else:
|
83 |
+
checkpoint = self.load_checkpoint(
|
84 |
+
checkpoint, map_location=map_location)
|
85 |
+
|
86 |
+
self._epoch = checkpoint['meta']['epoch']
|
87 |
+
self._iter = checkpoint['meta']['iter']
|
88 |
+
if 'optimizer' in checkpoint and resume_optimizer:
|
89 |
+
if isinstance(self.optimizer, Optimizer):
|
90 |
+
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
91 |
+
elif isinstance(self.optimizer, dict):
|
92 |
+
for k in self.optimizer.keys():
|
93 |
+
self.optimizer[k].load_state_dict(
|
94 |
+
checkpoint['optimizer'][k])
|
95 |
+
else:
|
96 |
+
raise TypeError(
|
97 |
+
'Optimizer should be dict or torch.optim.Optimizer '
|
98 |
+
f'but got {type(self.optimizer)}')
|
99 |
+
|
100 |
+
if 'amp' in checkpoint:
|
101 |
+
apex.amp.load_state_dict(checkpoint['amp'])
|
102 |
+
self.logger.info('load amp state dict')
|
103 |
+
|
104 |
+
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
|
mmdet/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mmcv
|
2 |
+
|
3 |
+
from .version import __version__, short_version
|
4 |
+
|
5 |
+
|
6 |
+
def digit_version(version_str):
|
7 |
+
digit_version = []
|
8 |
+
for x in version_str.split('.'):
|
9 |
+
if x.isdigit():
|
10 |
+
digit_version.append(int(x))
|
11 |
+
elif x.find('rc') != -1:
|
12 |
+
patch_version = x.split('rc')
|
13 |
+
digit_version.append(int(patch_version[0]) - 1)
|
14 |
+
digit_version.append(int(patch_version[1]))
|
15 |
+
return digit_version
|
16 |
+
|
17 |
+
|
18 |
+
mmcv_minimum_version = '1.2.4'
|
19 |
+
mmcv_maximum_version = '1.4.0'
|
20 |
+
mmcv_version = digit_version(mmcv.__version__)
|
21 |
+
|
22 |
+
|
23 |
+
assert (mmcv_version >= digit_version(mmcv_minimum_version)
|
24 |
+
and mmcv_version <= digit_version(mmcv_maximum_version)), \
|
25 |
+
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
26 |
+
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
|
27 |
+
|
28 |
+
__all__ = ['__version__', 'short_version']
|
mmdet/apis/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .inference import (async_inference_detector, inference_detector,
|
2 |
+
init_detector, show_result_pyplot)
|
3 |
+
from .test import multi_gpu_test, single_gpu_test
|
4 |
+
from .train import get_root_logger, set_random_seed, train_detector
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
|
8 |
+
'async_inference_detector', 'inference_detector', 'show_result_pyplot',
|
9 |
+
'multi_gpu_test', 'single_gpu_test'
|
10 |
+
]
|
mmdet/apis/inference.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import mmcv
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from mmcv.ops import RoIPool
|
7 |
+
from mmcv.parallel import collate, scatter
|
8 |
+
from mmcv.runner import load_checkpoint
|
9 |
+
|
10 |
+
from mmdet.core import get_classes
|
11 |
+
from mmdet.datasets import replace_ImageToTensor
|
12 |
+
from mmdet.datasets.pipelines import Compose
|
13 |
+
from mmdet.models import build_detector
|
14 |
+
|
15 |
+
|
16 |
+
def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
|
17 |
+
"""Initialize a detector from config file.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
config (str or :obj:`mmcv.Config`): Config file path or the config
|
21 |
+
object.
|
22 |
+
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
23 |
+
will not load any weights.
|
24 |
+
cfg_options (dict): Options to override some settings in the used
|
25 |
+
config.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
nn.Module: The constructed detector.
|
29 |
+
"""
|
30 |
+
if isinstance(config, str):
|
31 |
+
config = mmcv.Config.fromfile(config)
|
32 |
+
elif not isinstance(config, mmcv.Config):
|
33 |
+
raise TypeError('config must be a filename or Config object, '
|
34 |
+
f'but got {type(config)}')
|
35 |
+
if cfg_options is not None:
|
36 |
+
config.merge_from_dict(cfg_options)
|
37 |
+
config.model.pretrained = None
|
38 |
+
config.model.train_cfg = None
|
39 |
+
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
|
40 |
+
if checkpoint is not None:
|
41 |
+
map_loc = 'cpu' if device == 'cpu' else None
|
42 |
+
checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
|
43 |
+
if 'CLASSES' in checkpoint.get('meta', {}):
|
44 |
+
model.CLASSES = checkpoint['meta']['CLASSES']
|
45 |
+
else:
|
46 |
+
warnings.simplefilter('once')
|
47 |
+
warnings.warn('Class names are not saved in the checkpoint\'s '
|
48 |
+
'meta data, use COCO classes by default.')
|
49 |
+
model.CLASSES = get_classes('coco')
|
50 |
+
model.cfg = config # save the config in the model for convenience
|
51 |
+
model.to(device)
|
52 |
+
model.eval()
|
53 |
+
return model
|
54 |
+
|
55 |
+
|
56 |
+
class LoadImage(object):
|
57 |
+
"""Deprecated.
|
58 |
+
|
59 |
+
A simple pipeline to load image.
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __call__(self, results):
|
63 |
+
"""Call function to load images into results.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
results (dict): A result dict contains the file name
|
67 |
+
of the image to be read.
|
68 |
+
Returns:
|
69 |
+
dict: ``results`` will be returned containing loaded image.
|
70 |
+
"""
|
71 |
+
warnings.simplefilter('once')
|
72 |
+
warnings.warn('`LoadImage` is deprecated and will be removed in '
|
73 |
+
'future releases. You may use `LoadImageFromWebcam` '
|
74 |
+
'from `mmdet.datasets.pipelines.` instead.')
|
75 |
+
if isinstance(results['img'], str):
|
76 |
+
results['filename'] = results['img']
|
77 |
+
results['ori_filename'] = results['img']
|
78 |
+
else:
|
79 |
+
results['filename'] = None
|
80 |
+
results['ori_filename'] = None
|
81 |
+
img = mmcv.imread(results['img'])
|
82 |
+
results['img'] = img
|
83 |
+
results['img_fields'] = ['img']
|
84 |
+
results['img_shape'] = img.shape
|
85 |
+
results['ori_shape'] = img.shape
|
86 |
+
return results
|
87 |
+
|
88 |
+
|
89 |
+
def inference_detector(model, imgs):
|
90 |
+
"""Inference image(s) with the detector.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
model (nn.Module): The loaded detector.
|
94 |
+
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
|
95 |
+
Either image files or loaded images.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
If imgs is a list or tuple, the same length list type results
|
99 |
+
will be returned, otherwise return the detection results directly.
|
100 |
+
"""
|
101 |
+
|
102 |
+
if isinstance(imgs, (list, tuple)):
|
103 |
+
is_batch = True
|
104 |
+
else:
|
105 |
+
imgs = [imgs]
|
106 |
+
is_batch = False
|
107 |
+
|
108 |
+
cfg = model.cfg
|
109 |
+
device = next(model.parameters()).device # model device
|
110 |
+
|
111 |
+
if isinstance(imgs[0], np.ndarray):
|
112 |
+
cfg = cfg.copy()
|
113 |
+
# set loading pipeline type
|
114 |
+
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
|
115 |
+
|
116 |
+
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
|
117 |
+
test_pipeline = Compose(cfg.data.test.pipeline)
|
118 |
+
|
119 |
+
datas = []
|
120 |
+
for img in imgs:
|
121 |
+
# prepare data
|
122 |
+
if isinstance(img, np.ndarray):
|
123 |
+
# directly add img
|
124 |
+
data = dict(img=img)
|
125 |
+
else:
|
126 |
+
# add information into dict
|
127 |
+
data = dict(img_info=dict(filename=img), img_prefix=None)
|
128 |
+
# build the data pipeline
|
129 |
+
data = test_pipeline(data)
|
130 |
+
datas.append(data)
|
131 |
+
|
132 |
+
data = collate(datas, samples_per_gpu=len(imgs))
|
133 |
+
# just get the actual data from DataContainer
|
134 |
+
data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
|
135 |
+
data['img'] = [img.data[0] for img in data['img']]
|
136 |
+
if next(model.parameters()).is_cuda:
|
137 |
+
# scatter to specified GPU
|
138 |
+
data = scatter(data, [device])[0]
|
139 |
+
else:
|
140 |
+
for m in model.modules():
|
141 |
+
assert not isinstance(
|
142 |
+
m, RoIPool
|
143 |
+
), 'CPU inference with RoIPool is not supported currently.'
|
144 |
+
|
145 |
+
# forward the model
|
146 |
+
with torch.no_grad():
|
147 |
+
results = model(return_loss=False, rescale=True, **data)
|
148 |
+
|
149 |
+
if not is_batch:
|
150 |
+
return results[0]
|
151 |
+
else:
|
152 |
+
return results
|
153 |
+
|
154 |
+
|
155 |
+
async def async_inference_detector(model, img):
|
156 |
+
"""Async inference image(s) with the detector.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
model (nn.Module): The loaded detector.
|
160 |
+
img (str | ndarray): Either image files or loaded images.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
Awaitable detection results.
|
164 |
+
"""
|
165 |
+
cfg = model.cfg
|
166 |
+
device = next(model.parameters()).device # model device
|
167 |
+
# prepare data
|
168 |
+
if isinstance(img, np.ndarray):
|
169 |
+
# directly add img
|
170 |
+
data = dict(img=img)
|
171 |
+
cfg = cfg.copy()
|
172 |
+
# set loading pipeline type
|
173 |
+
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
|
174 |
+
else:
|
175 |
+
# add information into dict
|
176 |
+
data = dict(img_info=dict(filename=img), img_prefix=None)
|
177 |
+
# build the data pipeline
|
178 |
+
test_pipeline = Compose(cfg.data.test.pipeline)
|
179 |
+
data = test_pipeline(data)
|
180 |
+
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
|
181 |
+
|
182 |
+
# We don't restore `torch.is_grad_enabled()` value during concurrent
|
183 |
+
# inference since execution can overlap
|
184 |
+
torch.set_grad_enabled(False)
|
185 |
+
result = await model.aforward_test(rescale=True, **data)
|
186 |
+
return result
|
187 |
+
|
188 |
+
|
189 |
+
def show_result_pyplot(model,
|
190 |
+
img,
|
191 |
+
result,
|
192 |
+
score_thr=0.3,
|
193 |
+
title='result',
|
194 |
+
wait_time=0):
|
195 |
+
"""Visualize the detection results on the image.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
model (nn.Module): The loaded detector.
|
199 |
+
img (str or np.ndarray): Image filename or loaded image.
|
200 |
+
result (tuple[list] or list): The detection result, can be either
|
201 |
+
(bbox, segm) or just bbox.
|
202 |
+
score_thr (float): The threshold to visualize the bboxes and masks.
|
203 |
+
title (str): Title of the pyplot figure.
|
204 |
+
wait_time (float): Value of waitKey param.
|
205 |
+
Default: 0.
|
206 |
+
"""
|
207 |
+
if hasattr(model, 'module'):
|
208 |
+
model = model.module
|
209 |
+
model.show_result(
|
210 |
+
img,
|
211 |
+
result,
|
212 |
+
score_thr=score_thr,
|
213 |
+
show=True,
|
214 |
+
wait_time=wait_time,
|
215 |
+
win_name=title,
|
216 |
+
bbox_color=(72, 101, 241),
|
217 |
+
text_color=(72, 101, 241))
|
mmdet/apis/test.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import pickle
|
3 |
+
import shutil
|
4 |
+
import tempfile
|
5 |
+
import time
|
6 |
+
|
7 |
+
import mmcv
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
from mmcv.image import tensor2imgs
|
11 |
+
from mmcv.runner import get_dist_info
|
12 |
+
|
13 |
+
from mmdet.core import encode_mask_results
|
14 |
+
|
15 |
+
|
16 |
+
def single_gpu_test(model,
|
17 |
+
data_loader,
|
18 |
+
show=False,
|
19 |
+
out_dir=None,
|
20 |
+
show_score_thr=0.3):
|
21 |
+
model.eval()
|
22 |
+
results = []
|
23 |
+
dataset = data_loader.dataset
|
24 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
25 |
+
for i, data in enumerate(data_loader):
|
26 |
+
with torch.no_grad():
|
27 |
+
result = model(return_loss=False, rescale=True, **data)
|
28 |
+
|
29 |
+
batch_size = len(result)
|
30 |
+
if show or out_dir:
|
31 |
+
if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
|
32 |
+
img_tensor = data['img'][0]
|
33 |
+
else:
|
34 |
+
img_tensor = data['img'][0].data[0]
|
35 |
+
img_metas = data['img_metas'][0].data[0]
|
36 |
+
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
|
37 |
+
assert len(imgs) == len(img_metas)
|
38 |
+
|
39 |
+
for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
|
40 |
+
h, w, _ = img_meta['img_shape']
|
41 |
+
img_show = img[:h, :w, :]
|
42 |
+
|
43 |
+
ori_h, ori_w = img_meta['ori_shape'][:-1]
|
44 |
+
img_show = mmcv.imresize(img_show, (ori_w, ori_h))
|
45 |
+
|
46 |
+
if out_dir:
|
47 |
+
out_file = osp.join(out_dir, img_meta['ori_filename'])
|
48 |
+
else:
|
49 |
+
out_file = None
|
50 |
+
model.module.show_result(
|
51 |
+
img_show,
|
52 |
+
result[i],
|
53 |
+
show=show,
|
54 |
+
out_file=out_file,
|
55 |
+
score_thr=show_score_thr)
|
56 |
+
|
57 |
+
# encode mask results
|
58 |
+
if isinstance(result[0], tuple):
|
59 |
+
result = [(bbox_results, encode_mask_results(mask_results))
|
60 |
+
for bbox_results, mask_results in result]
|
61 |
+
results.extend(result)
|
62 |
+
|
63 |
+
for _ in range(batch_size):
|
64 |
+
prog_bar.update()
|
65 |
+
return results
|
66 |
+
|
67 |
+
|
68 |
+
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
|
69 |
+
"""Test model with multiple gpus.
|
70 |
+
|
71 |
+
This method tests model with multiple gpus and collects the results
|
72 |
+
under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
|
73 |
+
it encodes results to gpu tensors and use gpu communication for results
|
74 |
+
collection. On cpu mode it saves the results on different gpus to 'tmpdir'
|
75 |
+
and collects them by the rank 0 worker.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
model (nn.Module): Model to be tested.
|
79 |
+
data_loader (nn.Dataloader): Pytorch data loader.
|
80 |
+
tmpdir (str): Path of directory to save the temporary results from
|
81 |
+
different gpus under cpu mode.
|
82 |
+
gpu_collect (bool): Option to use either gpu or cpu to collect results.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
list: The prediction results.
|
86 |
+
"""
|
87 |
+
model.eval()
|
88 |
+
results = []
|
89 |
+
dataset = data_loader.dataset
|
90 |
+
rank, world_size = get_dist_info()
|
91 |
+
if rank == 0:
|
92 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
93 |
+
time.sleep(2) # This line can prevent deadlock problem in some cases.
|
94 |
+
for i, data in enumerate(data_loader):
|
95 |
+
with torch.no_grad():
|
96 |
+
result = model(return_loss=False, rescale=True, **data)
|
97 |
+
# encode mask results
|
98 |
+
if isinstance(result[0], tuple):
|
99 |
+
result = [(bbox_results, encode_mask_results(mask_results))
|
100 |
+
for bbox_results, mask_results in result]
|
101 |
+
results.extend(result)
|
102 |
+
|
103 |
+
if rank == 0:
|
104 |
+
batch_size = len(result)
|
105 |
+
for _ in range(batch_size * world_size):
|
106 |
+
prog_bar.update()
|
107 |
+
|
108 |
+
# collect results from all ranks
|
109 |
+
if gpu_collect:
|
110 |
+
results = collect_results_gpu(results, len(dataset))
|
111 |
+
else:
|
112 |
+
results = collect_results_cpu(results, len(dataset), tmpdir)
|
113 |
+
return results
|
114 |
+
|
115 |
+
|
116 |
+
def collect_results_cpu(result_part, size, tmpdir=None):
|
117 |
+
rank, world_size = get_dist_info()
|
118 |
+
# create a tmp dir if it is not specified
|
119 |
+
if tmpdir is None:
|
120 |
+
MAX_LEN = 512
|
121 |
+
# 32 is whitespace
|
122 |
+
dir_tensor = torch.full((MAX_LEN, ),
|
123 |
+
32,
|
124 |
+
dtype=torch.uint8,
|
125 |
+
device='cuda')
|
126 |
+
if rank == 0:
|
127 |
+
mmcv.mkdir_or_exist('.dist_test')
|
128 |
+
tmpdir = tempfile.mkdtemp(dir='.dist_test')
|
129 |
+
tmpdir = torch.tensor(
|
130 |
+
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
|
131 |
+
dir_tensor[:len(tmpdir)] = tmpdir
|
132 |
+
dist.broadcast(dir_tensor, 0)
|
133 |
+
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
|
134 |
+
else:
|
135 |
+
mmcv.mkdir_or_exist(tmpdir)
|
136 |
+
# dump the part result to the dir
|
137 |
+
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
|
138 |
+
dist.barrier()
|
139 |
+
# collect all parts
|
140 |
+
if rank != 0:
|
141 |
+
return None
|
142 |
+
else:
|
143 |
+
# load results of all parts from tmp dir
|
144 |
+
part_list = []
|
145 |
+
for i in range(world_size):
|
146 |
+
part_file = osp.join(tmpdir, f'part_{i}.pkl')
|
147 |
+
part_list.append(mmcv.load(part_file))
|
148 |
+
# sort the results
|
149 |
+
ordered_results = []
|
150 |
+
for res in zip(*part_list):
|
151 |
+
ordered_results.extend(list(res))
|
152 |
+
# the dataloader may pad some samples
|
153 |
+
ordered_results = ordered_results[:size]
|
154 |
+
# remove tmp dir
|
155 |
+
shutil.rmtree(tmpdir)
|
156 |
+
return ordered_results
|
157 |
+
|
158 |
+
|
159 |
+
def collect_results_gpu(result_part, size):
|
160 |
+
rank, world_size = get_dist_info()
|
161 |
+
# dump result part to tensor with pickle
|
162 |
+
part_tensor = torch.tensor(
|
163 |
+
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
|
164 |
+
# gather all result part tensor shape
|
165 |
+
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
|
166 |
+
shape_list = [shape_tensor.clone() for _ in range(world_size)]
|
167 |
+
dist.all_gather(shape_list, shape_tensor)
|
168 |
+
# padding result part tensor to max length
|
169 |
+
shape_max = torch.tensor(shape_list).max()
|
170 |
+
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
|
171 |
+
part_send[:shape_tensor[0]] = part_tensor
|
172 |
+
part_recv_list = [
|
173 |
+
part_tensor.new_zeros(shape_max) for _ in range(world_size)
|
174 |
+
]
|
175 |
+
# gather all result part
|
176 |
+
dist.all_gather(part_recv_list, part_send)
|
177 |
+
|
178 |
+
if rank == 0:
|
179 |
+
part_list = []
|
180 |
+
for recv, shape in zip(part_recv_list, shape_list):
|
181 |
+
part_list.append(
|
182 |
+
pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
|
183 |
+
# sort the results
|
184 |
+
ordered_results = []
|
185 |
+
for res in zip(*part_list):
|
186 |
+
ordered_results.extend(list(res))
|
187 |
+
# the dataloader may pad some samples
|
188 |
+
ordered_results = ordered_results[:size]
|
189 |
+
return ordered_results
|
mmdet/apis/train.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
7 |
+
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
|
8 |
+
Fp16OptimizerHook, OptimizerHook, build_optimizer,
|
9 |
+
build_runner)
|
10 |
+
from mmcv.utils import build_from_cfg
|
11 |
+
|
12 |
+
from mmdet.core import DistEvalHook, EvalHook
|
13 |
+
from mmdet.datasets import (build_dataloader, build_dataset,
|
14 |
+
replace_ImageToTensor)
|
15 |
+
from mmdet.utils import get_root_logger
|
16 |
+
from mmcv_custom.runner import EpochBasedRunnerAmp
|
17 |
+
try:
|
18 |
+
import apex
|
19 |
+
except:
|
20 |
+
print('apex is not installed')
|
21 |
+
|
22 |
+
|
23 |
+
def set_random_seed(seed, deterministic=False):
|
24 |
+
"""Set random seed.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
seed (int): Seed to be used.
|
28 |
+
deterministic (bool): Whether to set the deterministic option for
|
29 |
+
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
|
30 |
+
to True and `torch.backends.cudnn.benchmark` to False.
|
31 |
+
Default: False.
|
32 |
+
"""
|
33 |
+
random.seed(seed)
|
34 |
+
np.random.seed(seed)
|
35 |
+
torch.manual_seed(seed)
|
36 |
+
torch.cuda.manual_seed_all(seed)
|
37 |
+
if deterministic:
|
38 |
+
torch.backends.cudnn.deterministic = True
|
39 |
+
torch.backends.cudnn.benchmark = False
|
40 |
+
|
41 |
+
|
42 |
+
def train_detector(model,
|
43 |
+
dataset,
|
44 |
+
cfg,
|
45 |
+
distributed=False,
|
46 |
+
validate=False,
|
47 |
+
timestamp=None,
|
48 |
+
meta=None):
|
49 |
+
logger = get_root_logger(cfg.log_level)
|
50 |
+
|
51 |
+
# prepare data loaders
|
52 |
+
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
53 |
+
if 'imgs_per_gpu' in cfg.data:
|
54 |
+
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
|
55 |
+
'Please use "samples_per_gpu" instead')
|
56 |
+
if 'samples_per_gpu' in cfg.data:
|
57 |
+
logger.warning(
|
58 |
+
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
|
59 |
+
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
|
60 |
+
f'={cfg.data.imgs_per_gpu} is used in this experiments')
|
61 |
+
else:
|
62 |
+
logger.warning(
|
63 |
+
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
|
64 |
+
f'{cfg.data.imgs_per_gpu} in this experiments')
|
65 |
+
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
|
66 |
+
|
67 |
+
data_loaders = [
|
68 |
+
build_dataloader(
|
69 |
+
ds,
|
70 |
+
cfg.data.samples_per_gpu,
|
71 |
+
cfg.data.workers_per_gpu,
|
72 |
+
# cfg.gpus will be ignored if distributed
|
73 |
+
len(cfg.gpu_ids),
|
74 |
+
dist=distributed,
|
75 |
+
seed=cfg.seed) for ds in dataset
|
76 |
+
]
|
77 |
+
|
78 |
+
# build optimizer
|
79 |
+
optimizer = build_optimizer(model, cfg.optimizer)
|
80 |
+
|
81 |
+
# use apex fp16 optimizer
|
82 |
+
if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
|
83 |
+
if cfg.optimizer_config.get("use_fp16", False):
|
84 |
+
model, optimizer = apex.amp.initialize(
|
85 |
+
model.cuda(), optimizer, opt_level="O1")
|
86 |
+
for m in model.modules():
|
87 |
+
if hasattr(m, "fp16_enabled"):
|
88 |
+
m.fp16_enabled = True
|
89 |
+
|
90 |
+
# put model on gpus
|
91 |
+
if distributed:
|
92 |
+
find_unused_parameters = cfg.get('find_unused_parameters', False)
|
93 |
+
# Sets the `find_unused_parameters` parameter in
|
94 |
+
# torch.nn.parallel.DistributedDataParallel
|
95 |
+
model = MMDistributedDataParallel(
|
96 |
+
model.cuda(),
|
97 |
+
device_ids=[torch.cuda.current_device()],
|
98 |
+
broadcast_buffers=False,
|
99 |
+
find_unused_parameters=find_unused_parameters)
|
100 |
+
else:
|
101 |
+
model = MMDataParallel(
|
102 |
+
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
|
103 |
+
|
104 |
+
if 'runner' not in cfg:
|
105 |
+
cfg.runner = {
|
106 |
+
'type': 'EpochBasedRunner',
|
107 |
+
'max_epochs': cfg.total_epochs
|
108 |
+
}
|
109 |
+
warnings.warn(
|
110 |
+
'config is now expected to have a `runner` section, '
|
111 |
+
'please set `runner` in your config.', UserWarning)
|
112 |
+
else:
|
113 |
+
if 'total_epochs' in cfg:
|
114 |
+
assert cfg.total_epochs == cfg.runner.max_epochs
|
115 |
+
|
116 |
+
# build runner
|
117 |
+
runner = build_runner(
|
118 |
+
cfg.runner,
|
119 |
+
default_args=dict(
|
120 |
+
model=model,
|
121 |
+
optimizer=optimizer,
|
122 |
+
work_dir=cfg.work_dir,
|
123 |
+
logger=logger,
|
124 |
+
meta=meta))
|
125 |
+
|
126 |
+
# an ugly workaround to make .log and .log.json filenames the same
|
127 |
+
runner.timestamp = timestamp
|
128 |
+
|
129 |
+
# fp16 setting
|
130 |
+
fp16_cfg = cfg.get('fp16', None)
|
131 |
+
if fp16_cfg is not None:
|
132 |
+
optimizer_config = Fp16OptimizerHook(
|
133 |
+
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
|
134 |
+
elif distributed and 'type' not in cfg.optimizer_config:
|
135 |
+
optimizer_config = OptimizerHook(**cfg.optimizer_config)
|
136 |
+
else:
|
137 |
+
optimizer_config = cfg.optimizer_config
|
138 |
+
|
139 |
+
# register hooks
|
140 |
+
runner.register_training_hooks(cfg.lr_config, optimizer_config,
|
141 |
+
cfg.checkpoint_config, cfg.log_config,
|
142 |
+
cfg.get('momentum_config', None))
|
143 |
+
if distributed:
|
144 |
+
if isinstance(runner, EpochBasedRunner):
|
145 |
+
runner.register_hook(DistSamplerSeedHook())
|
146 |
+
|
147 |
+
# register eval hooks
|
148 |
+
if validate:
|
149 |
+
# Support batch_size > 1 in validation
|
150 |
+
val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
|
151 |
+
if val_samples_per_gpu > 1:
|
152 |
+
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
153 |
+
cfg.data.val.pipeline = replace_ImageToTensor(
|
154 |
+
cfg.data.val.pipeline)
|
155 |
+
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
156 |
+
val_dataloader = build_dataloader(
|
157 |
+
val_dataset,
|
158 |
+
samples_per_gpu=val_samples_per_gpu,
|
159 |
+
workers_per_gpu=cfg.data.workers_per_gpu,
|
160 |
+
dist=distributed,
|
161 |
+
shuffle=False)
|
162 |
+
eval_cfg = cfg.get('evaluation', {})
|
163 |
+
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
164 |
+
eval_hook = DistEvalHook if distributed else EvalHook
|
165 |
+
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
|
166 |
+
|
167 |
+
# user-defined hooks
|
168 |
+
if cfg.get('custom_hooks', None):
|
169 |
+
custom_hooks = cfg.custom_hooks
|
170 |
+
assert isinstance(custom_hooks, list), \
|
171 |
+
f'custom_hooks expect list type, but got {type(custom_hooks)}'
|
172 |
+
for hook_cfg in cfg.custom_hooks:
|
173 |
+
assert isinstance(hook_cfg, dict), \
|
174 |
+
'Each item in custom_hooks expects dict type, but got ' \
|
175 |
+
f'{type(hook_cfg)}'
|
176 |
+
hook_cfg = hook_cfg.copy()
|
177 |
+
priority = hook_cfg.pop('priority', 'NORMAL')
|
178 |
+
hook = build_from_cfg(hook_cfg, HOOKS)
|
179 |
+
runner.register_hook(hook, priority=priority)
|
180 |
+
|
181 |
+
if cfg.resume_from:
|
182 |
+
runner.resume(cfg.resume_from)
|
183 |
+
elif cfg.load_from:
|
184 |
+
runner.load_checkpoint(cfg.load_from)
|
185 |
+
runner.run(data_loaders, cfg.workflow)
|
mmdet/core/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .anchor import * # noqa: F401, F403
|
2 |
+
from .bbox import * # noqa: F401, F403
|
3 |
+
from .evaluation import * # noqa: F401, F403
|
4 |
+
from .export import * # noqa: F401, F403
|
5 |
+
from .mask import * # noqa: F401, F403
|
6 |
+
from .post_processing import * # noqa: F401, F403
|
7 |
+
from .utils import * # noqa: F401, F403
|
mmdet/core/anchor/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
|
2 |
+
YOLOAnchorGenerator)
|
3 |
+
from .builder import ANCHOR_GENERATORS, build_anchor_generator
|
4 |
+
from .point_generator import PointGenerator
|
5 |
+
from .utils import anchor_inside_flags, calc_region, images_to_levels
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
|
9 |
+
'PointGenerator', 'images_to_levels', 'calc_region',
|
10 |
+
'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
|
11 |
+
]
|
mmdet/core/anchor/anchor_generator.py
ADDED
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mmcv
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.nn.modules.utils import _pair
|
5 |
+
|
6 |
+
from .builder import ANCHOR_GENERATORS
|
7 |
+
|
8 |
+
|
9 |
+
@ANCHOR_GENERATORS.register_module()
|
10 |
+
class AnchorGenerator(object):
|
11 |
+
"""Standard anchor generator for 2D anchor-based detectors.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
strides (list[int] | list[tuple[int, int]]): Strides of anchors
|
15 |
+
in multiple feature levels in order (w, h).
|
16 |
+
ratios (list[float]): The list of ratios between the height and width
|
17 |
+
of anchors in a single level.
|
18 |
+
scales (list[int] | None): Anchor scales for anchors in a single level.
|
19 |
+
It cannot be set at the same time if `octave_base_scale` and
|
20 |
+
`scales_per_octave` are set.
|
21 |
+
base_sizes (list[int] | None): The basic sizes
|
22 |
+
of anchors in multiple levels.
|
23 |
+
If None is given, strides will be used as base_sizes.
|
24 |
+
(If strides are non square, the shortest stride is taken.)
|
25 |
+
scale_major (bool): Whether to multiply scales first when generating
|
26 |
+
base anchors. If true, the anchors in the same row will have the
|
27 |
+
same scales. By default it is True in V2.0
|
28 |
+
octave_base_scale (int): The base scale of octave.
|
29 |
+
scales_per_octave (int): Number of scales for each octave.
|
30 |
+
`octave_base_scale` and `scales_per_octave` are usually used in
|
31 |
+
retinanet and the `scales` should be None when they are set.
|
32 |
+
centers (list[tuple[float, float]] | None): The centers of the anchor
|
33 |
+
relative to the feature grid center in multiple feature levels.
|
34 |
+
By default it is set to be None and not used. If a list of tuple of
|
35 |
+
float is given, they will be used to shift the centers of anchors.
|
36 |
+
center_offset (float): The offset of center in proportion to anchors'
|
37 |
+
width and height. By default it is 0 in V2.0.
|
38 |
+
|
39 |
+
Examples:
|
40 |
+
>>> from mmdet.core import AnchorGenerator
|
41 |
+
>>> self = AnchorGenerator([16], [1.], [1.], [9])
|
42 |
+
>>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
|
43 |
+
>>> print(all_anchors)
|
44 |
+
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
|
45 |
+
[11.5000, -4.5000, 20.5000, 4.5000],
|
46 |
+
[-4.5000, 11.5000, 4.5000, 20.5000],
|
47 |
+
[11.5000, 11.5000, 20.5000, 20.5000]])]
|
48 |
+
>>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
|
49 |
+
>>> all_anchors = self.grid_anchors([(2, 2), (1, 1)], device='cpu')
|
50 |
+
>>> print(all_anchors)
|
51 |
+
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
|
52 |
+
[11.5000, -4.5000, 20.5000, 4.5000],
|
53 |
+
[-4.5000, 11.5000, 4.5000, 20.5000],
|
54 |
+
[11.5000, 11.5000, 20.5000, 20.5000]]), \
|
55 |
+
tensor([[-9., -9., 9., 9.]])]
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self,
|
59 |
+
strides,
|
60 |
+
ratios,
|
61 |
+
scales=None,
|
62 |
+
base_sizes=None,
|
63 |
+
scale_major=True,
|
64 |
+
octave_base_scale=None,
|
65 |
+
scales_per_octave=None,
|
66 |
+
centers=None,
|
67 |
+
center_offset=0.):
|
68 |
+
# check center and center_offset
|
69 |
+
if center_offset != 0:
|
70 |
+
assert centers is None, 'center cannot be set when center_offset' \
|
71 |
+
f'!=0, {centers} is given.'
|
72 |
+
if not (0 <= center_offset <= 1):
|
73 |
+
raise ValueError('center_offset should be in range [0, 1], '
|
74 |
+
f'{center_offset} is given.')
|
75 |
+
if centers is not None:
|
76 |
+
assert len(centers) == len(strides), \
|
77 |
+
'The number of strides should be the same as centers, got ' \
|
78 |
+
f'{strides} and {centers}'
|
79 |
+
|
80 |
+
# calculate base sizes of anchors
|
81 |
+
self.strides = [_pair(stride) for stride in strides]
|
82 |
+
self.base_sizes = [min(stride) for stride in self.strides
|
83 |
+
] if base_sizes is None else base_sizes
|
84 |
+
assert len(self.base_sizes) == len(self.strides), \
|
85 |
+
'The number of strides should be the same as base sizes, got ' \
|
86 |
+
f'{self.strides} and {self.base_sizes}'
|
87 |
+
|
88 |
+
# calculate scales of anchors
|
89 |
+
assert ((octave_base_scale is not None
|
90 |
+
and scales_per_octave is not None) ^ (scales is not None)), \
|
91 |
+
'scales and octave_base_scale with scales_per_octave cannot' \
|
92 |
+
' be set at the same time'
|
93 |
+
if scales is not None:
|
94 |
+
self.scales = torch.Tensor(scales)
|
95 |
+
elif octave_base_scale is not None and scales_per_octave is not None:
|
96 |
+
octave_scales = np.array(
|
97 |
+
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
|
98 |
+
scales = octave_scales * octave_base_scale
|
99 |
+
self.scales = torch.Tensor(scales)
|
100 |
+
else:
|
101 |
+
raise ValueError('Either scales or octave_base_scale with '
|
102 |
+
'scales_per_octave should be set')
|
103 |
+
|
104 |
+
self.octave_base_scale = octave_base_scale
|
105 |
+
self.scales_per_octave = scales_per_octave
|
106 |
+
self.ratios = torch.Tensor(ratios)
|
107 |
+
self.scale_major = scale_major
|
108 |
+
self.centers = centers
|
109 |
+
self.center_offset = center_offset
|
110 |
+
self.base_anchors = self.gen_base_anchors()
|
111 |
+
|
112 |
+
@property
|
113 |
+
def num_base_anchors(self):
|
114 |
+
"""list[int]: total number of base anchors in a feature grid"""
|
115 |
+
return [base_anchors.size(0) for base_anchors in self.base_anchors]
|
116 |
+
|
117 |
+
@property
|
118 |
+
def num_levels(self):
|
119 |
+
"""int: number of feature levels that the generator will be applied"""
|
120 |
+
return len(self.strides)
|
121 |
+
|
122 |
+
def gen_base_anchors(self):
|
123 |
+
"""Generate base anchors.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
list(torch.Tensor): Base anchors of a feature grid in multiple \
|
127 |
+
feature levels.
|
128 |
+
"""
|
129 |
+
multi_level_base_anchors = []
|
130 |
+
for i, base_size in enumerate(self.base_sizes):
|
131 |
+
center = None
|
132 |
+
if self.centers is not None:
|
133 |
+
center = self.centers[i]
|
134 |
+
multi_level_base_anchors.append(
|
135 |
+
self.gen_single_level_base_anchors(
|
136 |
+
base_size,
|
137 |
+
scales=self.scales,
|
138 |
+
ratios=self.ratios,
|
139 |
+
center=center))
|
140 |
+
return multi_level_base_anchors
|
141 |
+
|
142 |
+
def gen_single_level_base_anchors(self,
|
143 |
+
base_size,
|
144 |
+
scales,
|
145 |
+
ratios,
|
146 |
+
center=None):
|
147 |
+
"""Generate base anchors of a single level.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
base_size (int | float): Basic size of an anchor.
|
151 |
+
scales (torch.Tensor): Scales of the anchor.
|
152 |
+
ratios (torch.Tensor): The ratio between between the height
|
153 |
+
and width of anchors in a single level.
|
154 |
+
center (tuple[float], optional): The center of the base anchor
|
155 |
+
related to a single feature grid. Defaults to None.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Anchors in a single-level feature maps.
|
159 |
+
"""
|
160 |
+
w = base_size
|
161 |
+
h = base_size
|
162 |
+
if center is None:
|
163 |
+
x_center = self.center_offset * w
|
164 |
+
y_center = self.center_offset * h
|
165 |
+
else:
|
166 |
+
x_center, y_center = center
|
167 |
+
|
168 |
+
h_ratios = torch.sqrt(ratios)
|
169 |
+
w_ratios = 1 / h_ratios
|
170 |
+
if self.scale_major:
|
171 |
+
ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
|
172 |
+
hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
|
173 |
+
else:
|
174 |
+
ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
|
175 |
+
hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
|
176 |
+
|
177 |
+
# use float anchor and the anchor's center is aligned with the
|
178 |
+
# pixel center
|
179 |
+
base_anchors = [
|
180 |
+
x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
|
181 |
+
y_center + 0.5 * hs
|
182 |
+
]
|
183 |
+
base_anchors = torch.stack(base_anchors, dim=-1)
|
184 |
+
|
185 |
+
return base_anchors
|
186 |
+
|
187 |
+
def _meshgrid(self, x, y, row_major=True):
|
188 |
+
"""Generate mesh grid of x and y.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
x (torch.Tensor): Grids of x dimension.
|
192 |
+
y (torch.Tensor): Grids of y dimension.
|
193 |
+
row_major (bool, optional): Whether to return y grids first.
|
194 |
+
Defaults to True.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
tuple[torch.Tensor]: The mesh grids of x and y.
|
198 |
+
"""
|
199 |
+
# use shape instead of len to keep tracing while exporting to onnx
|
200 |
+
xx = x.repeat(y.shape[0])
|
201 |
+
yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
|
202 |
+
if row_major:
|
203 |
+
return xx, yy
|
204 |
+
else:
|
205 |
+
return yy, xx
|
206 |
+
|
207 |
+
def grid_anchors(self, featmap_sizes, device='cuda'):
|
208 |
+
"""Generate grid anchors in multiple feature levels.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
featmap_sizes (list[tuple]): List of feature map sizes in
|
212 |
+
multiple feature levels.
|
213 |
+
device (str): Device where the anchors will be put on.
|
214 |
+
|
215 |
+
Return:
|
216 |
+
list[torch.Tensor]: Anchors in multiple feature levels. \
|
217 |
+
The sizes of each tensor should be [N, 4], where \
|
218 |
+
N = width * height * num_base_anchors, width and height \
|
219 |
+
are the sizes of the corresponding feature level, \
|
220 |
+
num_base_anchors is the number of anchors for that level.
|
221 |
+
"""
|
222 |
+
assert self.num_levels == len(featmap_sizes)
|
223 |
+
multi_level_anchors = []
|
224 |
+
for i in range(self.num_levels):
|
225 |
+
anchors = self.single_level_grid_anchors(
|
226 |
+
self.base_anchors[i].to(device),
|
227 |
+
featmap_sizes[i],
|
228 |
+
self.strides[i],
|
229 |
+
device=device)
|
230 |
+
multi_level_anchors.append(anchors)
|
231 |
+
return multi_level_anchors
|
232 |
+
|
233 |
+
def single_level_grid_anchors(self,
|
234 |
+
base_anchors,
|
235 |
+
featmap_size,
|
236 |
+
stride=(16, 16),
|
237 |
+
device='cuda'):
|
238 |
+
"""Generate grid anchors of a single level.
|
239 |
+
|
240 |
+
Note:
|
241 |
+
This function is usually called by method ``self.grid_anchors``.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
base_anchors (torch.Tensor): The base anchors of a feature grid.
|
245 |
+
featmap_size (tuple[int]): Size of the feature maps.
|
246 |
+
stride (tuple[int], optional): Stride of the feature map in order
|
247 |
+
(w, h). Defaults to (16, 16).
|
248 |
+
device (str, optional): Device the tensor will be put on.
|
249 |
+
Defaults to 'cuda'.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
torch.Tensor: Anchors in the overall feature maps.
|
253 |
+
"""
|
254 |
+
# keep as Tensor, so that we can covert to ONNX correctly
|
255 |
+
feat_h, feat_w = featmap_size
|
256 |
+
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
|
257 |
+
shift_y = torch.arange(0, feat_h, device=device) * stride[1]
|
258 |
+
|
259 |
+
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
|
260 |
+
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
|
261 |
+
shifts = shifts.type_as(base_anchors)
|
262 |
+
# first feat_w elements correspond to the first row of shifts
|
263 |
+
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
|
264 |
+
# shifted anchors (K, A, 4), reshape to (K*A, 4)
|
265 |
+
|
266 |
+
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
|
267 |
+
all_anchors = all_anchors.view(-1, 4)
|
268 |
+
# first A rows correspond to A anchors of (0, 0) in feature map,
|
269 |
+
# then (0, 1), (0, 2), ...
|
270 |
+
return all_anchors
|
271 |
+
|
272 |
+
def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
|
273 |
+
"""Generate valid flags of anchors in multiple feature levels.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
featmap_sizes (list(tuple)): List of feature map sizes in
|
277 |
+
multiple feature levels.
|
278 |
+
pad_shape (tuple): The padded shape of the image.
|
279 |
+
device (str): Device where the anchors will be put on.
|
280 |
+
|
281 |
+
Return:
|
282 |
+
list(torch.Tensor): Valid flags of anchors in multiple levels.
|
283 |
+
"""
|
284 |
+
assert self.num_levels == len(featmap_sizes)
|
285 |
+
multi_level_flags = []
|
286 |
+
for i in range(self.num_levels):
|
287 |
+
anchor_stride = self.strides[i]
|
288 |
+
feat_h, feat_w = featmap_sizes[i]
|
289 |
+
h, w = pad_shape[:2]
|
290 |
+
valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
|
291 |
+
valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
|
292 |
+
flags = self.single_level_valid_flags((feat_h, feat_w),
|
293 |
+
(valid_feat_h, valid_feat_w),
|
294 |
+
self.num_base_anchors[i],
|
295 |
+
device=device)
|
296 |
+
multi_level_flags.append(flags)
|
297 |
+
return multi_level_flags
|
298 |
+
|
299 |
+
def single_level_valid_flags(self,
|
300 |
+
featmap_size,
|
301 |
+
valid_size,
|
302 |
+
num_base_anchors,
|
303 |
+
device='cuda'):
|
304 |
+
"""Generate the valid flags of anchor in a single feature map.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
featmap_size (tuple[int]): The size of feature maps.
|
308 |
+
valid_size (tuple[int]): The valid size of the feature maps.
|
309 |
+
num_base_anchors (int): The number of base anchors.
|
310 |
+
device (str, optional): Device where the flags will be put on.
|
311 |
+
Defaults to 'cuda'.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
torch.Tensor: The valid flags of each anchor in a single level \
|
315 |
+
feature map.
|
316 |
+
"""
|
317 |
+
feat_h, feat_w = featmap_size
|
318 |
+
valid_h, valid_w = valid_size
|
319 |
+
assert valid_h <= feat_h and valid_w <= feat_w
|
320 |
+
valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
|
321 |
+
valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
|
322 |
+
valid_x[:valid_w] = 1
|
323 |
+
valid_y[:valid_h] = 1
|
324 |
+
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
|
325 |
+
valid = valid_xx & valid_yy
|
326 |
+
valid = valid[:, None].expand(valid.size(0),
|
327 |
+
num_base_anchors).contiguous().view(-1)
|
328 |
+
return valid
|
329 |
+
|
330 |
+
def __repr__(self):
|
331 |
+
"""str: a string that describes the module"""
|
332 |
+
indent_str = ' '
|
333 |
+
repr_str = self.__class__.__name__ + '(\n'
|
334 |
+
repr_str += f'{indent_str}strides={self.strides},\n'
|
335 |
+
repr_str += f'{indent_str}ratios={self.ratios},\n'
|
336 |
+
repr_str += f'{indent_str}scales={self.scales},\n'
|
337 |
+
repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
|
338 |
+
repr_str += f'{indent_str}scale_major={self.scale_major},\n'
|
339 |
+
repr_str += f'{indent_str}octave_base_scale='
|
340 |
+
repr_str += f'{self.octave_base_scale},\n'
|
341 |
+
repr_str += f'{indent_str}scales_per_octave='
|
342 |
+
repr_str += f'{self.scales_per_octave},\n'
|
343 |
+
repr_str += f'{indent_str}num_levels={self.num_levels}\n'
|
344 |
+
repr_str += f'{indent_str}centers={self.centers},\n'
|
345 |
+
repr_str += f'{indent_str}center_offset={self.center_offset})'
|
346 |
+
return repr_str
|
347 |
+
|
348 |
+
|
349 |
+
@ANCHOR_GENERATORS.register_module()
|
350 |
+
class SSDAnchorGenerator(AnchorGenerator):
|
351 |
+
"""Anchor generator for SSD.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
strides (list[int] | list[tuple[int, int]]): Strides of anchors
|
355 |
+
in multiple feature levels.
|
356 |
+
ratios (list[float]): The list of ratios between the height and width
|
357 |
+
of anchors in a single level.
|
358 |
+
basesize_ratio_range (tuple(float)): Ratio range of anchors.
|
359 |
+
input_size (int): Size of feature map, 300 for SSD300,
|
360 |
+
512 for SSD512.
|
361 |
+
scale_major (bool): Whether to multiply scales first when generating
|
362 |
+
base anchors. If true, the anchors in the same row will have the
|
363 |
+
same scales. It is always set to be False in SSD.
|
364 |
+
"""
|
365 |
+
|
366 |
+
def __init__(self,
|
367 |
+
strides,
|
368 |
+
ratios,
|
369 |
+
basesize_ratio_range,
|
370 |
+
input_size=300,
|
371 |
+
scale_major=True):
|
372 |
+
assert len(strides) == len(ratios)
|
373 |
+
assert mmcv.is_tuple_of(basesize_ratio_range, float)
|
374 |
+
|
375 |
+
self.strides = [_pair(stride) for stride in strides]
|
376 |
+
self.input_size = input_size
|
377 |
+
self.centers = [(stride[0] / 2., stride[1] / 2.)
|
378 |
+
for stride in self.strides]
|
379 |
+
self.basesize_ratio_range = basesize_ratio_range
|
380 |
+
|
381 |
+
# calculate anchor ratios and sizes
|
382 |
+
min_ratio, max_ratio = basesize_ratio_range
|
383 |
+
min_ratio = int(min_ratio * 100)
|
384 |
+
max_ratio = int(max_ratio * 100)
|
385 |
+
step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
|
386 |
+
min_sizes = []
|
387 |
+
max_sizes = []
|
388 |
+
for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
|
389 |
+
min_sizes.append(int(self.input_size * ratio / 100))
|
390 |
+
max_sizes.append(int(self.input_size * (ratio + step) / 100))
|
391 |
+
if self.input_size == 300:
|
392 |
+
if basesize_ratio_range[0] == 0.15: # SSD300 COCO
|
393 |
+
min_sizes.insert(0, int(self.input_size * 7 / 100))
|
394 |
+
max_sizes.insert(0, int(self.input_size * 15 / 100))
|
395 |
+
elif basesize_ratio_range[0] == 0.2: # SSD300 VOC
|
396 |
+
min_sizes.insert(0, int(self.input_size * 10 / 100))
|
397 |
+
max_sizes.insert(0, int(self.input_size * 20 / 100))
|
398 |
+
else:
|
399 |
+
raise ValueError(
|
400 |
+
'basesize_ratio_range[0] should be either 0.15'
|
401 |
+
'or 0.2 when input_size is 300, got '
|
402 |
+
f'{basesize_ratio_range[0]}.')
|
403 |
+
elif self.input_size == 512:
|
404 |
+
if basesize_ratio_range[0] == 0.1: # SSD512 COCO
|
405 |
+
min_sizes.insert(0, int(self.input_size * 4 / 100))
|
406 |
+
max_sizes.insert(0, int(self.input_size * 10 / 100))
|
407 |
+
elif basesize_ratio_range[0] == 0.15: # SSD512 VOC
|
408 |
+
min_sizes.insert(0, int(self.input_size * 7 / 100))
|
409 |
+
max_sizes.insert(0, int(self.input_size * 15 / 100))
|
410 |
+
else:
|
411 |
+
raise ValueError('basesize_ratio_range[0] should be either 0.1'
|
412 |
+
'or 0.15 when input_size is 512, got'
|
413 |
+
f' {basesize_ratio_range[0]}.')
|
414 |
+
else:
|
415 |
+
raise ValueError('Only support 300 or 512 in SSDAnchorGenerator'
|
416 |
+
f', got {self.input_size}.')
|
417 |
+
|
418 |
+
anchor_ratios = []
|
419 |
+
anchor_scales = []
|
420 |
+
for k in range(len(self.strides)):
|
421 |
+
scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
|
422 |
+
anchor_ratio = [1.]
|
423 |
+
for r in ratios[k]:
|
424 |
+
anchor_ratio += [1 / r, r] # 4 or 6 ratio
|
425 |
+
anchor_ratios.append(torch.Tensor(anchor_ratio))
|
426 |
+
anchor_scales.append(torch.Tensor(scales))
|
427 |
+
|
428 |
+
self.base_sizes = min_sizes
|
429 |
+
self.scales = anchor_scales
|
430 |
+
self.ratios = anchor_ratios
|
431 |
+
self.scale_major = scale_major
|
432 |
+
self.center_offset = 0
|
433 |
+
self.base_anchors = self.gen_base_anchors()
|
434 |
+
|
435 |
+
def gen_base_anchors(self):
|
436 |
+
"""Generate base anchors.
|
437 |
+
|
438 |
+
Returns:
|
439 |
+
list(torch.Tensor): Base anchors of a feature grid in multiple \
|
440 |
+
feature levels.
|
441 |
+
"""
|
442 |
+
multi_level_base_anchors = []
|
443 |
+
for i, base_size in enumerate(self.base_sizes):
|
444 |
+
base_anchors = self.gen_single_level_base_anchors(
|
445 |
+
base_size,
|
446 |
+
scales=self.scales[i],
|
447 |
+
ratios=self.ratios[i],
|
448 |
+
center=self.centers[i])
|
449 |
+
indices = list(range(len(self.ratios[i])))
|
450 |
+
indices.insert(1, len(indices))
|
451 |
+
base_anchors = torch.index_select(base_anchors, 0,
|
452 |
+
torch.LongTensor(indices))
|
453 |
+
multi_level_base_anchors.append(base_anchors)
|
454 |
+
return multi_level_base_anchors
|
455 |
+
|
456 |
+
def __repr__(self):
|
457 |
+
"""str: a string that describes the module"""
|
458 |
+
indent_str = ' '
|
459 |
+
repr_str = self.__class__.__name__ + '(\n'
|
460 |
+
repr_str += f'{indent_str}strides={self.strides},\n'
|
461 |
+
repr_str += f'{indent_str}scales={self.scales},\n'
|
462 |
+
repr_str += f'{indent_str}scale_major={self.scale_major},\n'
|
463 |
+
repr_str += f'{indent_str}input_size={self.input_size},\n'
|
464 |
+
repr_str += f'{indent_str}scales={self.scales},\n'
|
465 |
+
repr_str += f'{indent_str}ratios={self.ratios},\n'
|
466 |
+
repr_str += f'{indent_str}num_levels={self.num_levels},\n'
|
467 |
+
repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
|
468 |
+
repr_str += f'{indent_str}basesize_ratio_range='
|
469 |
+
repr_str += f'{self.basesize_ratio_range})'
|
470 |
+
return repr_str
|
471 |
+
|
472 |
+
|
473 |
+
@ANCHOR_GENERATORS.register_module()
|
474 |
+
class LegacyAnchorGenerator(AnchorGenerator):
|
475 |
+
"""Legacy anchor generator used in MMDetection V1.x.
|
476 |
+
|
477 |
+
Note:
|
478 |
+
Difference to the V2.0 anchor generator:
|
479 |
+
|
480 |
+
1. The center offset of V1.x anchors are set to be 0.5 rather than 0.
|
481 |
+
2. The width/height are minused by 1 when calculating the anchors' \
|
482 |
+
centers and corners to meet the V1.x coordinate system.
|
483 |
+
3. The anchors' corners are quantized.
|
484 |
+
|
485 |
+
Args:
|
486 |
+
strides (list[int] | list[tuple[int]]): Strides of anchors
|
487 |
+
in multiple feature levels.
|
488 |
+
ratios (list[float]): The list of ratios between the height and width
|
489 |
+
of anchors in a single level.
|
490 |
+
scales (list[int] | None): Anchor scales for anchors in a single level.
|
491 |
+
It cannot be set at the same time if `octave_base_scale` and
|
492 |
+
`scales_per_octave` are set.
|
493 |
+
base_sizes (list[int]): The basic sizes of anchors in multiple levels.
|
494 |
+
If None is given, strides will be used to generate base_sizes.
|
495 |
+
scale_major (bool): Whether to multiply scales first when generating
|
496 |
+
base anchors. If true, the anchors in the same row will have the
|
497 |
+
same scales. By default it is True in V2.0
|
498 |
+
octave_base_scale (int): The base scale of octave.
|
499 |
+
scales_per_octave (int): Number of scales for each octave.
|
500 |
+
`octave_base_scale` and `scales_per_octave` are usually used in
|
501 |
+
retinanet and the `scales` should be None when they are set.
|
502 |
+
centers (list[tuple[float, float]] | None): The centers of the anchor
|
503 |
+
relative to the feature grid center in multiple feature levels.
|
504 |
+
By default it is set to be None and not used. It a list of float
|
505 |
+
is given, this list will be used to shift the centers of anchors.
|
506 |
+
center_offset (float): The offset of center in propotion to anchors'
|
507 |
+
width and height. By default it is 0.5 in V2.0 but it should be 0.5
|
508 |
+
in v1.x models.
|
509 |
+
|
510 |
+
Examples:
|
511 |
+
>>> from mmdet.core import LegacyAnchorGenerator
|
512 |
+
>>> self = LegacyAnchorGenerator(
|
513 |
+
>>> [16], [1.], [1.], [9], center_offset=0.5)
|
514 |
+
>>> all_anchors = self.grid_anchors(((2, 2),), device='cpu')
|
515 |
+
>>> print(all_anchors)
|
516 |
+
[tensor([[ 0., 0., 8., 8.],
|
517 |
+
[16., 0., 24., 8.],
|
518 |
+
[ 0., 16., 8., 24.],
|
519 |
+
[16., 16., 24., 24.]])]
|
520 |
+
"""
|
521 |
+
|
522 |
+
def gen_single_level_base_anchors(self,
|
523 |
+
base_size,
|
524 |
+
scales,
|
525 |
+
ratios,
|
526 |
+
center=None):
|
527 |
+
"""Generate base anchors of a single level.
|
528 |
+
|
529 |
+
Note:
|
530 |
+
The width/height of anchors are minused by 1 when calculating \
|
531 |
+
the centers and corners to meet the V1.x coordinate system.
|
532 |
+
|
533 |
+
Args:
|
534 |
+
base_size (int | float): Basic size of an anchor.
|
535 |
+
scales (torch.Tensor): Scales of the anchor.
|
536 |
+
ratios (torch.Tensor): The ratio between between the height.
|
537 |
+
and width of anchors in a single level.
|
538 |
+
center (tuple[float], optional): The center of the base anchor
|
539 |
+
related to a single feature grid. Defaults to None.
|
540 |
+
|
541 |
+
Returns:
|
542 |
+
torch.Tensor: Anchors in a single-level feature map.
|
543 |
+
"""
|
544 |
+
w = base_size
|
545 |
+
h = base_size
|
546 |
+
if center is None:
|
547 |
+
x_center = self.center_offset * (w - 1)
|
548 |
+
y_center = self.center_offset * (h - 1)
|
549 |
+
else:
|
550 |
+
x_center, y_center = center
|
551 |
+
|
552 |
+
h_ratios = torch.sqrt(ratios)
|
553 |
+
w_ratios = 1 / h_ratios
|
554 |
+
if self.scale_major:
|
555 |
+
ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
|
556 |
+
hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
|
557 |
+
else:
|
558 |
+
ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
|
559 |
+
hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
|
560 |
+
|
561 |
+
# use float anchor and the anchor's center is aligned with the
|
562 |
+
# pixel center
|
563 |
+
base_anchors = [
|
564 |
+
x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
|
565 |
+
x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
|
566 |
+
]
|
567 |
+
base_anchors = torch.stack(base_anchors, dim=-1).round()
|
568 |
+
|
569 |
+
return base_anchors
|
570 |
+
|
571 |
+
|
572 |
+
@ANCHOR_GENERATORS.register_module()
|
573 |
+
class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
|
574 |
+
"""Legacy anchor generator used in MMDetection V1.x.
|
575 |
+
|
576 |
+
The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator`
|
577 |
+
can be found in `LegacyAnchorGenerator`.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self,
|
581 |
+
strides,
|
582 |
+
ratios,
|
583 |
+
basesize_ratio_range,
|
584 |
+
input_size=300,
|
585 |
+
scale_major=True):
|
586 |
+
super(LegacySSDAnchorGenerator,
|
587 |
+
self).__init__(strides, ratios, basesize_ratio_range, input_size,
|
588 |
+
scale_major)
|
589 |
+
self.centers = [((stride - 1) / 2., (stride - 1) / 2.)
|
590 |
+
for stride in strides]
|
591 |
+
self.base_anchors = self.gen_base_anchors()
|
592 |
+
|
593 |
+
|
594 |
+
@ANCHOR_GENERATORS.register_module()
|
595 |
+
class YOLOAnchorGenerator(AnchorGenerator):
|
596 |
+
"""Anchor generator for YOLO.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
strides (list[int] | list[tuple[int, int]]): Strides of anchors
|
600 |
+
in multiple feature levels.
|
601 |
+
base_sizes (list[list[tuple[int, int]]]): The basic sizes
|
602 |
+
of anchors in multiple levels.
|
603 |
+
"""
|
604 |
+
|
605 |
+
def __init__(self, strides, base_sizes):
|
606 |
+
self.strides = [_pair(stride) for stride in strides]
|
607 |
+
self.centers = [(stride[0] / 2., stride[1] / 2.)
|
608 |
+
for stride in self.strides]
|
609 |
+
self.base_sizes = []
|
610 |
+
num_anchor_per_level = len(base_sizes[0])
|
611 |
+
for base_sizes_per_level in base_sizes:
|
612 |
+
assert num_anchor_per_level == len(base_sizes_per_level)
|
613 |
+
self.base_sizes.append(
|
614 |
+
[_pair(base_size) for base_size in base_sizes_per_level])
|
615 |
+
self.base_anchors = self.gen_base_anchors()
|
616 |
+
|
617 |
+
@property
|
618 |
+
def num_levels(self):
|
619 |
+
"""int: number of feature levels that the generator will be applied"""
|
620 |
+
return len(self.base_sizes)
|
621 |
+
|
622 |
+
def gen_base_anchors(self):
|
623 |
+
"""Generate base anchors.
|
624 |
+
|
625 |
+
Returns:
|
626 |
+
list(torch.Tensor): Base anchors of a feature grid in multiple \
|
627 |
+
feature levels.
|
628 |
+
"""
|
629 |
+
multi_level_base_anchors = []
|
630 |
+
for i, base_sizes_per_level in enumerate(self.base_sizes):
|
631 |
+
center = None
|
632 |
+
if self.centers is not None:
|
633 |
+
center = self.centers[i]
|
634 |
+
multi_level_base_anchors.append(
|
635 |
+
self.gen_single_level_base_anchors(base_sizes_per_level,
|
636 |
+
center))
|
637 |
+
return multi_level_base_anchors
|
638 |
+
|
639 |
+
def gen_single_level_base_anchors(self, base_sizes_per_level, center=None):
|
640 |
+
"""Generate base anchors of a single level.
|
641 |
+
|
642 |
+
Args:
|
643 |
+
base_sizes_per_level (list[tuple[int, int]]): Basic sizes of
|
644 |
+
anchors.
|
645 |
+
center (tuple[float], optional): The center of the base anchor
|
646 |
+
related to a single feature grid. Defaults to None.
|
647 |
+
|
648 |
+
Returns:
|
649 |
+
torch.Tensor: Anchors in a single-level feature maps.
|
650 |
+
"""
|
651 |
+
x_center, y_center = center
|
652 |
+
base_anchors = []
|
653 |
+
for base_size in base_sizes_per_level:
|
654 |
+
w, h = base_size
|
655 |
+
|
656 |
+
# use float anchor and the anchor's center is aligned with the
|
657 |
+
# pixel center
|
658 |
+
base_anchor = torch.Tensor([
|
659 |
+
x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w,
|
660 |
+
y_center + 0.5 * h
|
661 |
+
])
|
662 |
+
base_anchors.append(base_anchor)
|
663 |
+
base_anchors = torch.stack(base_anchors, dim=0)
|
664 |
+
|
665 |
+
return base_anchors
|
666 |
+
|
667 |
+
def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'):
|
668 |
+
"""Generate responsible anchor flags of grid cells in multiple scales.
|
669 |
+
|
670 |
+
Args:
|
671 |
+
featmap_sizes (list(tuple)): List of feature map sizes in multiple
|
672 |
+
feature levels.
|
673 |
+
gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
|
674 |
+
device (str): Device where the anchors will be put on.
|
675 |
+
|
676 |
+
Return:
|
677 |
+
list(torch.Tensor): responsible flags of anchors in multiple level
|
678 |
+
"""
|
679 |
+
assert self.num_levels == len(featmap_sizes)
|
680 |
+
multi_level_responsible_flags = []
|
681 |
+
for i in range(self.num_levels):
|
682 |
+
anchor_stride = self.strides[i]
|
683 |
+
flags = self.single_level_responsible_flags(
|
684 |
+
featmap_sizes[i],
|
685 |
+
gt_bboxes,
|
686 |
+
anchor_stride,
|
687 |
+
self.num_base_anchors[i],
|
688 |
+
device=device)
|
689 |
+
multi_level_responsible_flags.append(flags)
|
690 |
+
return multi_level_responsible_flags
|
691 |
+
|
692 |
+
def single_level_responsible_flags(self,
|
693 |
+
featmap_size,
|
694 |
+
gt_bboxes,
|
695 |
+
stride,
|
696 |
+
num_base_anchors,
|
697 |
+
device='cuda'):
|
698 |
+
"""Generate the responsible flags of anchor in a single feature map.
|
699 |
+
|
700 |
+
Args:
|
701 |
+
featmap_size (tuple[int]): The size of feature maps.
|
702 |
+
gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
|
703 |
+
stride (tuple(int)): stride of current level
|
704 |
+
num_base_anchors (int): The number of base anchors.
|
705 |
+
device (str, optional): Device where the flags will be put on.
|
706 |
+
Defaults to 'cuda'.
|
707 |
+
|
708 |
+
Returns:
|
709 |
+
torch.Tensor: The valid flags of each anchor in a single level \
|
710 |
+
feature map.
|
711 |
+
"""
|
712 |
+
feat_h, feat_w = featmap_size
|
713 |
+
gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
|
714 |
+
gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
|
715 |
+
gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
|
716 |
+
gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()
|
717 |
+
|
718 |
+
# row major indexing
|
719 |
+
gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x
|
720 |
+
|
721 |
+
responsible_grid = torch.zeros(
|
722 |
+
feat_h * feat_w, dtype=torch.uint8, device=device)
|
723 |
+
responsible_grid[gt_bboxes_grid_idx] = 1
|
724 |
+
|
725 |
+
responsible_grid = responsible_grid[:, None].expand(
|
726 |
+
responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
|
727 |
+
return responsible_grid
|
mmdet/core/anchor/builder.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcv.utils import Registry, build_from_cfg
|
2 |
+
|
3 |
+
ANCHOR_GENERATORS = Registry('Anchor generator')
|
4 |
+
|
5 |
+
|
6 |
+
def build_anchor_generator(cfg, default_args=None):
|
7 |
+
return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
|
mmdet/core/anchor/point_generator.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from .builder import ANCHOR_GENERATORS
|
4 |
+
|
5 |
+
|
6 |
+
@ANCHOR_GENERATORS.register_module()
|
7 |
+
class PointGenerator(object):
|
8 |
+
|
9 |
+
def _meshgrid(self, x, y, row_major=True):
|
10 |
+
xx = x.repeat(len(y))
|
11 |
+
yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
|
12 |
+
if row_major:
|
13 |
+
return xx, yy
|
14 |
+
else:
|
15 |
+
return yy, xx
|
16 |
+
|
17 |
+
def grid_points(self, featmap_size, stride=16, device='cuda'):
|
18 |
+
feat_h, feat_w = featmap_size
|
19 |
+
shift_x = torch.arange(0., feat_w, device=device) * stride
|
20 |
+
shift_y = torch.arange(0., feat_h, device=device) * stride
|
21 |
+
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
|
22 |
+
stride = shift_x.new_full((shift_xx.shape[0], ), stride)
|
23 |
+
shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
|
24 |
+
all_points = shifts.to(device)
|
25 |
+
return all_points
|
26 |
+
|
27 |
+
def valid_flags(self, featmap_size, valid_size, device='cuda'):
|
28 |
+
feat_h, feat_w = featmap_size
|
29 |
+
valid_h, valid_w = valid_size
|
30 |
+
assert valid_h <= feat_h and valid_w <= feat_w
|
31 |
+
valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
|
32 |
+
valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
|
33 |
+
valid_x[:valid_w] = 1
|
34 |
+
valid_y[:valid_h] = 1
|
35 |
+
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
|
36 |
+
valid = valid_xx & valid_yy
|
37 |
+
return valid
|
mmdet/core/anchor/utils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def images_to_levels(target, num_levels):
|
5 |
+
"""Convert targets by image to targets by feature level.
|
6 |
+
|
7 |
+
[target_img0, target_img1] -> [target_level0, target_level1, ...]
|
8 |
+
"""
|
9 |
+
target = torch.stack(target, 0)
|
10 |
+
level_targets = []
|
11 |
+
start = 0
|
12 |
+
for n in num_levels:
|
13 |
+
end = start + n
|
14 |
+
# level_targets.append(target[:, start:end].squeeze(0))
|
15 |
+
level_targets.append(target[:, start:end])
|
16 |
+
start = end
|
17 |
+
return level_targets
|
18 |
+
|
19 |
+
|
20 |
+
def anchor_inside_flags(flat_anchors,
|
21 |
+
valid_flags,
|
22 |
+
img_shape,
|
23 |
+
allowed_border=0):
|
24 |
+
"""Check whether the anchors are inside the border.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
|
28 |
+
valid_flags (torch.Tensor): An existing valid flags of anchors.
|
29 |
+
img_shape (tuple(int)): Shape of current image.
|
30 |
+
allowed_border (int, optional): The border to allow the valid anchor.
|
31 |
+
Defaults to 0.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
torch.Tensor: Flags indicating whether the anchors are inside a \
|
35 |
+
valid range.
|
36 |
+
"""
|
37 |
+
img_h, img_w = img_shape[:2]
|
38 |
+
if allowed_border >= 0:
|
39 |
+
inside_flags = valid_flags & \
|
40 |
+
(flat_anchors[:, 0] >= -allowed_border) & \
|
41 |
+
(flat_anchors[:, 1] >= -allowed_border) & \
|
42 |
+
(flat_anchors[:, 2] < img_w + allowed_border) & \
|
43 |
+
(flat_anchors[:, 3] < img_h + allowed_border)
|
44 |
+
else:
|
45 |
+
inside_flags = valid_flags
|
46 |
+
return inside_flags
|
47 |
+
|
48 |
+
|
49 |
+
def calc_region(bbox, ratio, featmap_size=None):
|
50 |
+
"""Calculate a proportional bbox region.
|
51 |
+
|
52 |
+
The bbox center are fixed and the new h' and w' is h * ratio and w * ratio.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
|
56 |
+
ratio (float): Ratio of the output region.
|
57 |
+
featmap_size (tuple): Feature map size used for clipping the boundary.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
tuple: x1, y1, x2, y2
|
61 |
+
"""
|
62 |
+
x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
|
63 |
+
y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
|
64 |
+
x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
|
65 |
+
y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
|
66 |
+
if featmap_size is not None:
|
67 |
+
x1 = x1.clamp(min=0, max=featmap_size[1])
|
68 |
+
y1 = y1.clamp(min=0, max=featmap_size[0])
|
69 |
+
x2 = x2.clamp(min=0, max=featmap_size[1])
|
70 |
+
y2 = y2.clamp(min=0, max=featmap_size[0])
|
71 |
+
return (x1, y1, x2, y2)
|
mmdet/core/bbox/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
|
2 |
+
MaxIoUAssigner, RegionAssigner)
|
3 |
+
from .builder import build_assigner, build_bbox_coder, build_sampler
|
4 |
+
from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
|
5 |
+
TBLRBBoxCoder)
|
6 |
+
from .iou_calculators import BboxOverlaps2D, bbox_overlaps
|
7 |
+
from .samplers import (BaseSampler, CombinedSampler,
|
8 |
+
InstanceBalancedPosSampler, IoUBalancedNegSampler,
|
9 |
+
OHEMSampler, PseudoSampler, RandomSampler,
|
10 |
+
SamplingResult, ScoreHLRSampler)
|
11 |
+
from .transforms import (bbox2distance, bbox2result, bbox2roi,
|
12 |
+
bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
|
13 |
+
bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
|
14 |
+
distance2bbox, roi2bbox)
|
15 |
+
|
16 |
+
__all__ = [
|
17 |
+
'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
|
18 |
+
'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
|
19 |
+
'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
|
20 |
+
'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
|
21 |
+
'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
|
22 |
+
'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
|
23 |
+
'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
|
24 |
+
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
|
25 |
+
'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh',
|
26 |
+
'RegionAssigner'
|
27 |
+
]
|
mmdet/core/bbox/assigners/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .approx_max_iou_assigner import ApproxMaxIoUAssigner
|
2 |
+
from .assign_result import AssignResult
|
3 |
+
from .atss_assigner import ATSSAssigner
|
4 |
+
from .base_assigner import BaseAssigner
|
5 |
+
from .center_region_assigner import CenterRegionAssigner
|
6 |
+
from .grid_assigner import GridAssigner
|
7 |
+
from .hungarian_assigner import HungarianAssigner
|
8 |
+
from .max_iou_assigner import MaxIoUAssigner
|
9 |
+
from .point_assigner import PointAssigner
|
10 |
+
from .region_assigner import RegionAssigner
|
11 |
+
|
12 |
+
__all__ = [
|
13 |
+
'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
|
14 |
+
'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
|
15 |
+
'HungarianAssigner', 'RegionAssigner'
|
16 |
+
]
|
mmdet/core/bbox/assigners/approx_max_iou_assigner.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..builder import BBOX_ASSIGNERS
|
4 |
+
from ..iou_calculators import build_iou_calculator
|
5 |
+
from .max_iou_assigner import MaxIoUAssigner
|
6 |
+
|
7 |
+
|
8 |
+
@BBOX_ASSIGNERS.register_module()
|
9 |
+
class ApproxMaxIoUAssigner(MaxIoUAssigner):
|
10 |
+
"""Assign a corresponding gt bbox or background to each bbox.
|
11 |
+
|
12 |
+
Each proposals will be assigned with an integer indicating the ground truth
|
13 |
+
index. (semi-positive index: gt label (0-based), -1: background)
|
14 |
+
|
15 |
+
- -1: negative sample, no assigned gt
|
16 |
+
- semi-positive integer: positive sample, index (0-based) of assigned gt
|
17 |
+
|
18 |
+
Args:
|
19 |
+
pos_iou_thr (float): IoU threshold for positive bboxes.
|
20 |
+
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
|
21 |
+
min_pos_iou (float): Minimum iou for a bbox to be considered as a
|
22 |
+
positive bbox. Positive samples can have smaller IoU than
|
23 |
+
pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
|
24 |
+
gt_max_assign_all (bool): Whether to assign all bboxes with the same
|
25 |
+
highest overlap with some gt to that gt.
|
26 |
+
ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
|
27 |
+
`gt_bboxes_ignore` is specified). Negative values mean not
|
28 |
+
ignoring any bboxes.
|
29 |
+
ignore_wrt_candidates (bool): Whether to compute the iof between
|
30 |
+
`bboxes` and `gt_bboxes_ignore`, or the contrary.
|
31 |
+
match_low_quality (bool): Whether to allow quality matches. This is
|
32 |
+
usually allowed for RPN and single stage detectors, but not allowed
|
33 |
+
in the second stage.
|
34 |
+
gpu_assign_thr (int): The upper bound of the number of GT for GPU
|
35 |
+
assign. When the number of gt is above this threshold, will assign
|
36 |
+
on CPU device. Negative values mean not assign on CPU.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self,
|
40 |
+
pos_iou_thr,
|
41 |
+
neg_iou_thr,
|
42 |
+
min_pos_iou=.0,
|
43 |
+
gt_max_assign_all=True,
|
44 |
+
ignore_iof_thr=-1,
|
45 |
+
ignore_wrt_candidates=True,
|
46 |
+
match_low_quality=True,
|
47 |
+
gpu_assign_thr=-1,
|
48 |
+
iou_calculator=dict(type='BboxOverlaps2D')):
|
49 |
+
self.pos_iou_thr = pos_iou_thr
|
50 |
+
self.neg_iou_thr = neg_iou_thr
|
51 |
+
self.min_pos_iou = min_pos_iou
|
52 |
+
self.gt_max_assign_all = gt_max_assign_all
|
53 |
+
self.ignore_iof_thr = ignore_iof_thr
|
54 |
+
self.ignore_wrt_candidates = ignore_wrt_candidates
|
55 |
+
self.gpu_assign_thr = gpu_assign_thr
|
56 |
+
self.match_low_quality = match_low_quality
|
57 |
+
self.iou_calculator = build_iou_calculator(iou_calculator)
|
58 |
+
|
59 |
+
def assign(self,
|
60 |
+
approxs,
|
61 |
+
squares,
|
62 |
+
approxs_per_octave,
|
63 |
+
gt_bboxes,
|
64 |
+
gt_bboxes_ignore=None,
|
65 |
+
gt_labels=None):
|
66 |
+
"""Assign gt to approxs.
|
67 |
+
|
68 |
+
This method assign a gt bbox to each group of approxs (bboxes),
|
69 |
+
each group of approxs is represent by a base approx (bbox) and
|
70 |
+
will be assigned with -1, or a semi-positive number.
|
71 |
+
background_label (-1) means negative sample,
|
72 |
+
semi-positive number is the index (0-based) of assigned gt.
|
73 |
+
The assignment is done in following steps, the order matters.
|
74 |
+
|
75 |
+
1. assign every bbox to background_label (-1)
|
76 |
+
2. use the max IoU of each group of approxs to assign
|
77 |
+
2. assign proposals whose iou with all gts < neg_iou_thr to background
|
78 |
+
3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
|
79 |
+
assign it to that bbox
|
80 |
+
4. for each gt bbox, assign its nearest proposals (may be more than
|
81 |
+
one) to itself
|
82 |
+
|
83 |
+
Args:
|
84 |
+
approxs (Tensor): Bounding boxes to be assigned,
|
85 |
+
shape(approxs_per_octave*n, 4).
|
86 |
+
squares (Tensor): Base Bounding boxes to be assigned,
|
87 |
+
shape(n, 4).
|
88 |
+
approxs_per_octave (int): number of approxs per octave
|
89 |
+
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
|
90 |
+
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
|
91 |
+
labelled as `ignored`, e.g., crowd boxes in COCO.
|
92 |
+
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
:obj:`AssignResult`: The assign result.
|
96 |
+
"""
|
97 |
+
num_squares = squares.size(0)
|
98 |
+
num_gts = gt_bboxes.size(0)
|
99 |
+
|
100 |
+
if num_squares == 0 or num_gts == 0:
|
101 |
+
# No predictions and/or truth, return empty assignment
|
102 |
+
overlaps = approxs.new(num_gts, num_squares)
|
103 |
+
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
|
104 |
+
return assign_result
|
105 |
+
|
106 |
+
# re-organize anchors by approxs_per_octave x num_squares
|
107 |
+
approxs = torch.transpose(
|
108 |
+
approxs.view(num_squares, approxs_per_octave, 4), 0,
|
109 |
+
1).contiguous().view(-1, 4)
|
110 |
+
assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
|
111 |
+
num_gts > self.gpu_assign_thr) else False
|
112 |
+
# compute overlap and assign gt on CPU when number of GT is large
|
113 |
+
if assign_on_cpu:
|
114 |
+
device = approxs.device
|
115 |
+
approxs = approxs.cpu()
|
116 |
+
gt_bboxes = gt_bboxes.cpu()
|
117 |
+
if gt_bboxes_ignore is not None:
|
118 |
+
gt_bboxes_ignore = gt_bboxes_ignore.cpu()
|
119 |
+
if gt_labels is not None:
|
120 |
+
gt_labels = gt_labels.cpu()
|
121 |
+
all_overlaps = self.iou_calculator(approxs, gt_bboxes)
|
122 |
+
|
123 |
+
overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
|
124 |
+
num_gts).max(dim=0)
|
125 |
+
overlaps = torch.transpose(overlaps, 0, 1)
|
126 |
+
|
127 |
+
if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
|
128 |
+
and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0):
|
129 |
+
if self.ignore_wrt_candidates:
|
130 |
+
ignore_overlaps = self.iou_calculator(
|
131 |
+
squares, gt_bboxes_ignore, mode='iof')
|
132 |
+
ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
|
133 |
+
else:
|
134 |
+
ignore_overlaps = self.iou_calculator(
|
135 |
+
gt_bboxes_ignore, squares, mode='iof')
|
136 |
+
ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
|
137 |
+
overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
|
138 |
+
|
139 |
+
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
|
140 |
+
if assign_on_cpu:
|
141 |
+
assign_result.gt_inds = assign_result.gt_inds.to(device)
|
142 |
+
assign_result.max_overlaps = assign_result.max_overlaps.to(device)
|
143 |
+
if assign_result.labels is not None:
|
144 |
+
assign_result.labels = assign_result.labels.to(device)
|
145 |
+
return assign_result
|
mmdet/core/bbox/assigners/assign_result.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from mmdet.utils import util_mixins
|
4 |
+
|
5 |
+
|
6 |
+
class AssignResult(util_mixins.NiceRepr):
|
7 |
+
"""Stores assignments between predicted and truth boxes.
|
8 |
+
|
9 |
+
Attributes:
|
10 |
+
num_gts (int): the number of truth boxes considered when computing this
|
11 |
+
assignment
|
12 |
+
|
13 |
+
gt_inds (LongTensor): for each predicted box indicates the 1-based
|
14 |
+
index of the assigned truth box. 0 means unassigned and -1 means
|
15 |
+
ignore.
|
16 |
+
|
17 |
+
max_overlaps (FloatTensor): the iou between the predicted box and its
|
18 |
+
assigned truth box.
|
19 |
+
|
20 |
+
labels (None | LongTensor): If specified, for each predicted box
|
21 |
+
indicates the category label of the assigned truth box.
|
22 |
+
|
23 |
+
Example:
|
24 |
+
>>> # An assign result between 4 predicted boxes and 9 true boxes
|
25 |
+
>>> # where only two boxes were assigned.
|
26 |
+
>>> num_gts = 9
|
27 |
+
>>> max_overlaps = torch.LongTensor([0, .5, .9, 0])
|
28 |
+
>>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
|
29 |
+
>>> labels = torch.LongTensor([0, 3, 4, 0])
|
30 |
+
>>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
|
31 |
+
>>> print(str(self)) # xdoctest: +IGNORE_WANT
|
32 |
+
<AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
|
33 |
+
labels.shape=(4,))>
|
34 |
+
>>> # Force addition of gt labels (when adding gt as proposals)
|
35 |
+
>>> new_labels = torch.LongTensor([3, 4, 5])
|
36 |
+
>>> self.add_gt_(new_labels)
|
37 |
+
>>> print(str(self)) # xdoctest: +IGNORE_WANT
|
38 |
+
<AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
|
39 |
+
labels.shape=(7,))>
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
|
43 |
+
self.num_gts = num_gts
|
44 |
+
self.gt_inds = gt_inds
|
45 |
+
self.max_overlaps = max_overlaps
|
46 |
+
self.labels = labels
|
47 |
+
# Interface for possible user-defined properties
|
48 |
+
self._extra_properties = {}
|
49 |
+
|
50 |
+
@property
|
51 |
+
def num_preds(self):
|
52 |
+
"""int: the number of predictions in this assignment"""
|
53 |
+
return len(self.gt_inds)
|
54 |
+
|
55 |
+
def set_extra_property(self, key, value):
|
56 |
+
"""Set user-defined new property."""
|
57 |
+
assert key not in self.info
|
58 |
+
self._extra_properties[key] = value
|
59 |
+
|
60 |
+
def get_extra_property(self, key):
|
61 |
+
"""Get user-defined property."""
|
62 |
+
return self._extra_properties.get(key, None)
|
63 |
+
|
64 |
+
@property
|
65 |
+
def info(self):
|
66 |
+
"""dict: a dictionary of info about the object"""
|
67 |
+
basic_info = {
|
68 |
+
'num_gts': self.num_gts,
|
69 |
+
'num_preds': self.num_preds,
|
70 |
+
'gt_inds': self.gt_inds,
|
71 |
+
'max_overlaps': self.max_overlaps,
|
72 |
+
'labels': self.labels,
|
73 |
+
}
|
74 |
+
basic_info.update(self._extra_properties)
|
75 |
+
return basic_info
|
76 |
+
|
77 |
+
def __nice__(self):
|
78 |
+
"""str: a "nice" summary string describing this assign result"""
|
79 |
+
parts = []
|
80 |
+
parts.append(f'num_gts={self.num_gts!r}')
|
81 |
+
if self.gt_inds is None:
|
82 |
+
parts.append(f'gt_inds={self.gt_inds!r}')
|
83 |
+
else:
|
84 |
+
parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
|
85 |
+
if self.max_overlaps is None:
|
86 |
+
parts.append(f'max_overlaps={self.max_overlaps!r}')
|
87 |
+
else:
|
88 |
+
parts.append('max_overlaps.shape='
|
89 |
+
f'{tuple(self.max_overlaps.shape)!r}')
|
90 |
+
if self.labels is None:
|
91 |
+
parts.append(f'labels={self.labels!r}')
|
92 |
+
else:
|
93 |
+
parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
|
94 |
+
return ', '.join(parts)
|
95 |
+
|
96 |
+
@classmethod
|
97 |
+
def random(cls, **kwargs):
|
98 |
+
"""Create random AssignResult for tests or debugging.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
num_preds: number of predicted boxes
|
102 |
+
num_gts: number of true boxes
|
103 |
+
p_ignore (float): probability of a predicted box assinged to an
|
104 |
+
ignored truth
|
105 |
+
p_assigned (float): probability of a predicted box not being
|
106 |
+
assigned
|
107 |
+
p_use_label (float | bool): with labels or not
|
108 |
+
rng (None | int | numpy.random.RandomState): seed or state
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
:obj:`AssignResult`: Randomly generated assign results.
|
112 |
+
|
113 |
+
Example:
|
114 |
+
>>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
|
115 |
+
>>> self = AssignResult.random()
|
116 |
+
>>> print(self.info)
|
117 |
+
"""
|
118 |
+
from mmdet.core.bbox import demodata
|
119 |
+
rng = demodata.ensure_rng(kwargs.get('rng', None))
|
120 |
+
|
121 |
+
num_gts = kwargs.get('num_gts', None)
|
122 |
+
num_preds = kwargs.get('num_preds', None)
|
123 |
+
p_ignore = kwargs.get('p_ignore', 0.3)
|
124 |
+
p_assigned = kwargs.get('p_assigned', 0.7)
|
125 |
+
p_use_label = kwargs.get('p_use_label', 0.5)
|
126 |
+
num_classes = kwargs.get('p_use_label', 3)
|
127 |
+
|
128 |
+
if num_gts is None:
|
129 |
+
num_gts = rng.randint(0, 8)
|
130 |
+
if num_preds is None:
|
131 |
+
num_preds = rng.randint(0, 16)
|
132 |
+
|
133 |
+
if num_gts == 0:
|
134 |
+
max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
|
135 |
+
gt_inds = torch.zeros(num_preds, dtype=torch.int64)
|
136 |
+
if p_use_label is True or p_use_label < rng.rand():
|
137 |
+
labels = torch.zeros(num_preds, dtype=torch.int64)
|
138 |
+
else:
|
139 |
+
labels = None
|
140 |
+
else:
|
141 |
+
import numpy as np
|
142 |
+
# Create an overlap for each predicted box
|
143 |
+
max_overlaps = torch.from_numpy(rng.rand(num_preds))
|
144 |
+
|
145 |
+
# Construct gt_inds for each predicted box
|
146 |
+
is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
|
147 |
+
# maximum number of assignments constraints
|
148 |
+
n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
|
149 |
+
|
150 |
+
assigned_idxs = np.where(is_assigned)[0]
|
151 |
+
rng.shuffle(assigned_idxs)
|
152 |
+
assigned_idxs = assigned_idxs[0:n_assigned]
|
153 |
+
assigned_idxs.sort()
|
154 |
+
|
155 |
+
is_assigned[:] = 0
|
156 |
+
is_assigned[assigned_idxs] = True
|
157 |
+
|
158 |
+
is_ignore = torch.from_numpy(
|
159 |
+
rng.rand(num_preds) < p_ignore) & is_assigned
|
160 |
+
|
161 |
+
gt_inds = torch.zeros(num_preds, dtype=torch.int64)
|
162 |
+
|
163 |
+
true_idxs = np.arange(num_gts)
|
164 |
+
rng.shuffle(true_idxs)
|
165 |
+
true_idxs = torch.from_numpy(true_idxs)
|
166 |
+
gt_inds[is_assigned] = true_idxs[:n_assigned]
|
167 |
+
|
168 |
+
gt_inds = torch.from_numpy(
|
169 |
+
rng.randint(1, num_gts + 1, size=num_preds))
|
170 |
+
gt_inds[is_ignore] = -1
|
171 |
+
gt_inds[~is_assigned] = 0
|
172 |
+
max_overlaps[~is_assigned] = 0
|
173 |
+
|
174 |
+
if p_use_label is True or p_use_label < rng.rand():
|
175 |
+
if num_classes == 0:
|
176 |
+
labels = torch.zeros(num_preds, dtype=torch.int64)
|
177 |
+
else:
|
178 |
+
labels = torch.from_numpy(
|
179 |
+
# remind that we set FG labels to [0, num_class-1]
|
180 |
+
# since mmdet v2.0
|
181 |
+
# BG cat_id: num_class
|
182 |
+
rng.randint(0, num_classes, size=num_preds))
|
183 |
+
labels[~is_assigned] = 0
|
184 |
+
else:
|
185 |
+
labels = None
|
186 |
+
|
187 |
+
self = cls(num_gts, gt_inds, max_overlaps, labels)
|
188 |
+
return self
|
189 |
+
|
190 |
+
def add_gt_(self, gt_labels):
|
191 |
+
"""Add ground truth as assigned results.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
gt_labels (torch.Tensor): Labels of gt boxes
|
195 |
+
"""
|
196 |
+
self_inds = torch.arange(
|
197 |
+
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
|
198 |
+
self.gt_inds = torch.cat([self_inds, self.gt_inds])
|
199 |
+
|
200 |
+
self.max_overlaps = torch.cat(
|
201 |
+
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
|
202 |
+
|
203 |
+
if self.labels is not None:
|
204 |
+
self.labels = torch.cat([gt_labels, self.labels])
|
mmdet/core/bbox/assigners/atss_assigner.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..builder import BBOX_ASSIGNERS
|
4 |
+
from ..iou_calculators import build_iou_calculator
|
5 |
+
from .assign_result import AssignResult
|
6 |
+
from .base_assigner import BaseAssigner
|
7 |
+
|
8 |
+
|
9 |
+
@BBOX_ASSIGNERS.register_module()
|
10 |
+
class ATSSAssigner(BaseAssigner):
|
11 |
+
"""Assign a corresponding gt bbox or background to each bbox.
|
12 |
+
|
13 |
+
Each proposals will be assigned with `0` or a positive integer
|
14 |
+
indicating the ground truth index.
|
15 |
+
|
16 |
+
- 0: negative sample, no assigned gt
|
17 |
+
- positive integer: positive sample, index (1-based) of assigned gt
|
18 |
+
|
19 |
+
Args:
|
20 |
+
topk (float): number of bbox selected in each level
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self,
|
24 |
+
topk,
|
25 |
+
iou_calculator=dict(type='BboxOverlaps2D'),
|
26 |
+
ignore_iof_thr=-1):
|
27 |
+
self.topk = topk
|
28 |
+
self.iou_calculator = build_iou_calculator(iou_calculator)
|
29 |
+
self.ignore_iof_thr = ignore_iof_thr
|
30 |
+
|
31 |
+
# https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
|
32 |
+
|
33 |
+
def assign(self,
|
34 |
+
bboxes,
|
35 |
+
num_level_bboxes,
|
36 |
+
gt_bboxes,
|
37 |
+
gt_bboxes_ignore=None,
|
38 |
+
gt_labels=None):
|
39 |
+
"""Assign gt to bboxes.
|
40 |
+
|
41 |
+
The assignment is done in following steps
|
42 |
+
|
43 |
+
1. compute iou between all bbox (bbox of all pyramid levels) and gt
|
44 |
+
2. compute center distance between all bbox and gt
|
45 |
+
3. on each pyramid level, for each gt, select k bbox whose center
|
46 |
+
are closest to the gt center, so we total select k*l bbox as
|
47 |
+
candidates for each gt
|
48 |
+
4. get corresponding iou for the these candidates, and compute the
|
49 |
+
mean and std, set mean + std as the iou threshold
|
50 |
+
5. select these candidates whose iou are greater than or equal to
|
51 |
+
the threshold as positive
|
52 |
+
6. limit the positive sample's center in gt
|
53 |
+
|
54 |
+
|
55 |
+
Args:
|
56 |
+
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
|
57 |
+
num_level_bboxes (List): num of bboxes in each level
|
58 |
+
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
|
59 |
+
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
|
60 |
+
labelled as `ignored`, e.g., crowd boxes in COCO.
|
61 |
+
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
:obj:`AssignResult`: The assign result.
|
65 |
+
"""
|
66 |
+
INF = 100000000
|
67 |
+
bboxes = bboxes[:, :4]
|
68 |
+
num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
|
69 |
+
|
70 |
+
# compute iou between all bbox and gt
|
71 |
+
overlaps = self.iou_calculator(bboxes, gt_bboxes)
|
72 |
+
|
73 |
+
# assign 0 by default
|
74 |
+
assigned_gt_inds = overlaps.new_full((num_bboxes, ),
|
75 |
+
0,
|
76 |
+
dtype=torch.long)
|
77 |
+
|
78 |
+
if num_gt == 0 or num_bboxes == 0:
|
79 |
+
# No ground truth or boxes, return empty assignment
|
80 |
+
max_overlaps = overlaps.new_zeros((num_bboxes, ))
|
81 |
+
if num_gt == 0:
|
82 |
+
# No truth, assign everything to background
|
83 |
+
assigned_gt_inds[:] = 0
|
84 |
+
if gt_labels is None:
|
85 |
+
assigned_labels = None
|
86 |
+
else:
|
87 |
+
assigned_labels = overlaps.new_full((num_bboxes, ),
|
88 |
+
-1,
|
89 |
+
dtype=torch.long)
|
90 |
+
return AssignResult(
|
91 |
+
num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
|
92 |
+
|
93 |
+
# compute center distance between all bbox and gt
|
94 |
+
gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
|
95 |
+
gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
|
96 |
+
gt_points = torch.stack((gt_cx, gt_cy), dim=1)
|
97 |
+
|
98 |
+
bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
|
99 |
+
bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
|
100 |
+
bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)
|
101 |
+
|
102 |
+
distances = (bboxes_points[:, None, :] -
|
103 |
+
gt_points[None, :, :]).pow(2).sum(-1).sqrt()
|
104 |
+
|
105 |
+
if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
|
106 |
+
and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
|
107 |
+
ignore_overlaps = self.iou_calculator(
|
108 |
+
bboxes, gt_bboxes_ignore, mode='iof')
|
109 |
+
ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
|
110 |
+
ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
|
111 |
+
distances[ignore_idxs, :] = INF
|
112 |
+
assigned_gt_inds[ignore_idxs] = -1
|
113 |
+
|
114 |
+
# Selecting candidates based on the center distance
|
115 |
+
candidate_idxs = []
|
116 |
+
start_idx = 0
|
117 |
+
for level, bboxes_per_level in enumerate(num_level_bboxes):
|
118 |
+
# on each pyramid level, for each gt,
|
119 |
+
# select k bbox whose center are closest to the gt center
|
120 |
+
end_idx = start_idx + bboxes_per_level
|
121 |
+
distances_per_level = distances[start_idx:end_idx, :]
|
122 |
+
selectable_k = min(self.topk, bboxes_per_level)
|
123 |
+
_, topk_idxs_per_level = distances_per_level.topk(
|
124 |
+
selectable_k, dim=0, largest=False)
|
125 |
+
candidate_idxs.append(topk_idxs_per_level + start_idx)
|
126 |
+
start_idx = end_idx
|
127 |
+
candidate_idxs = torch.cat(candidate_idxs, dim=0)
|
128 |
+
|
129 |
+
# get corresponding iou for the these candidates, and compute the
|
130 |
+
# mean and std, set mean + std as the iou threshold
|
131 |
+
candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
|
132 |
+
overlaps_mean_per_gt = candidate_overlaps.mean(0)
|
133 |
+
overlaps_std_per_gt = candidate_overlaps.std(0)
|
134 |
+
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
|
135 |
+
|
136 |
+
is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
|
137 |
+
|
138 |
+
# limit the positive sample's center in gt
|
139 |
+
for gt_idx in range(num_gt):
|
140 |
+
candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
|
141 |
+
ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
|
142 |
+
num_gt, num_bboxes).contiguous().view(-1)
|
143 |
+
ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
|
144 |
+
num_gt, num_bboxes).contiguous().view(-1)
|
145 |
+
candidate_idxs = candidate_idxs.view(-1)
|
146 |
+
|
147 |
+
# calculate the left, top, right, bottom distance between positive
|
148 |
+
# bbox center and gt side
|
149 |
+
l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
|
150 |
+
t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
|
151 |
+
r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
|
152 |
+
b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
|
153 |
+
is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
|
154 |
+
is_pos = is_pos & is_in_gts
|
155 |
+
|
156 |
+
# if an anchor box is assigned to multiple gts,
|
157 |
+
# the one with the highest IoU will be selected.
|
158 |
+
overlaps_inf = torch.full_like(overlaps,
|
159 |
+
-INF).t().contiguous().view(-1)
|
160 |
+
index = candidate_idxs.view(-1)[is_pos.view(-1)]
|
161 |
+
overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
|
162 |
+
overlaps_inf = overlaps_inf.view(num_gt, -1).t()
|
163 |
+
|
164 |
+
max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
|
165 |
+
assigned_gt_inds[
|
166 |
+
max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
|
167 |
+
|
168 |
+
if gt_labels is not None:
|
169 |
+
assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
|
170 |
+
pos_inds = torch.nonzero(
|
171 |
+
assigned_gt_inds > 0, as_tuple=False).squeeze()
|
172 |
+
if pos_inds.numel() > 0:
|
173 |
+
assigned_labels[pos_inds] = gt_labels[
|
174 |
+
assigned_gt_inds[pos_inds] - 1]
|
175 |
+
else:
|
176 |
+
assigned_labels = None
|
177 |
+
return AssignResult(
|
178 |
+
num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
|
mmdet/core/bbox/assigners/base_assigner.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABCMeta, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class BaseAssigner(metaclass=ABCMeta):
|
5 |
+
"""Base assigner that assigns boxes to ground truth boxes."""
|
6 |
+
|
7 |
+
@abstractmethod
|
8 |
+
def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
|
9 |
+
"""Assign boxes to either a ground truth boxes or a negative boxes."""
|
mmdet/core/bbox/assigners/center_region_assigner.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..builder import BBOX_ASSIGNERS
|
4 |
+
from ..iou_calculators import build_iou_calculator
|
5 |
+
from .assign_result import AssignResult
|
6 |
+
from .base_assigner import BaseAssigner
|
7 |
+
|
8 |
+
|
9 |
+
def scale_boxes(bboxes, scale):
|
10 |
+
"""Expand an array of boxes by a given scale.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
bboxes (Tensor): Shape (m, 4)
|
14 |
+
scale (float): The scale factor of bboxes
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
(Tensor): Shape (m, 4). Scaled bboxes
|
18 |
+
"""
|
19 |
+
assert bboxes.size(1) == 4
|
20 |
+
w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
|
21 |
+
h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
|
22 |
+
x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
|
23 |
+
y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
|
24 |
+
|
25 |
+
w_half *= scale
|
26 |
+
h_half *= scale
|
27 |
+
|
28 |
+
boxes_scaled = torch.zeros_like(bboxes)
|
29 |
+
boxes_scaled[:, 0] = x_c - w_half
|
30 |
+
boxes_scaled[:, 2] = x_c + w_half
|
31 |
+
boxes_scaled[:, 1] = y_c - h_half
|
32 |
+
boxes_scaled[:, 3] = y_c + h_half
|
33 |
+
return boxes_scaled
|
34 |
+
|
35 |
+
|
36 |
+
def is_located_in(points, bboxes):
|
37 |
+
"""Are points located in bboxes.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
points (Tensor): Points, shape: (m, 2).
|
41 |
+
bboxes (Tensor): Bounding boxes, shape: (n, 4).
|
42 |
+
|
43 |
+
Return:
|
44 |
+
Tensor: Flags indicating if points are located in bboxes, shape: (m, n).
|
45 |
+
"""
|
46 |
+
assert points.size(1) == 2
|
47 |
+
assert bboxes.size(1) == 4
|
48 |
+
return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
|
49 |
+
(points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
|
50 |
+
(points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
|
51 |
+
(points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
|
52 |
+
|
53 |
+
|
54 |
+
def bboxes_area(bboxes):
|
55 |
+
"""Compute the area of an array of bboxes.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
bboxes (Tensor): The coordinates ox bboxes. Shape: (m, 4)
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
Tensor: Area of the bboxes. Shape: (m, )
|
62 |
+
"""
|
63 |
+
assert bboxes.size(1) == 4
|
64 |
+
w = (bboxes[:, 2] - bboxes[:, 0])
|
65 |
+
h = (bboxes[:, 3] - bboxes[:, 1])
|
66 |
+
areas = w * h
|
67 |
+
return areas
|
68 |
+
|
69 |
+
|
70 |
+
@BBOX_ASSIGNERS.register_module()
|
71 |
+
class CenterRegionAssigner(BaseAssigner):
|
72 |
+
"""Assign pixels at the center region of a bbox as positive.
|
73 |
+
|
74 |
+
Each proposals will be assigned with `-1`, `0`, or a positive integer
|
75 |
+
indicating the ground truth index.
|
76 |
+
- -1: negative samples
|
77 |
+
- semi-positive numbers: positive sample, index (0-based) of assigned gt
|
78 |
+
|
79 |
+
Args:
|
80 |
+
pos_scale (float): Threshold within which pixels are
|
81 |
+
labelled as positive.
|
82 |
+
neg_scale (float): Threshold above which pixels are
|
83 |
+
labelled as positive.
|
84 |
+
min_pos_iof (float): Minimum iof of a pixel with a gt to be
|
85 |
+
labelled as positive. Default: 1e-2
|
86 |
+
ignore_gt_scale (float): Threshold within which the pixels
|
87 |
+
are ignored when the gt is labelled as shadowed. Default: 0.5
|
88 |
+
foreground_dominate (bool): If True, the bbox will be assigned as
|
89 |
+
positive when a gt's kernel region overlaps with another's shadowed
|
90 |
+
(ignored) region, otherwise it is set as ignored. Default to False.
|
91 |
+
"""
|
92 |
+
|
93 |
+
def __init__(self,
|
94 |
+
pos_scale,
|
95 |
+
neg_scale,
|
96 |
+
min_pos_iof=1e-2,
|
97 |
+
ignore_gt_scale=0.5,
|
98 |
+
foreground_dominate=False,
|
99 |
+
iou_calculator=dict(type='BboxOverlaps2D')):
|
100 |
+
self.pos_scale = pos_scale
|
101 |
+
self.neg_scale = neg_scale
|
102 |
+
self.min_pos_iof = min_pos_iof
|
103 |
+
self.ignore_gt_scale = ignore_gt_scale
|
104 |
+
self.foreground_dominate = foreground_dominate
|
105 |
+
self.iou_calculator = build_iou_calculator(iou_calculator)
|
106 |
+
|
107 |
+
def get_gt_priorities(self, gt_bboxes):
|
108 |
+
"""Get gt priorities according to their areas.
|
109 |
+
|
110 |
+
Smaller gt has higher priority.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
Tensor: The priority of gts so that gts with larger priority is \
|
117 |
+
more likely to be assigned. Shape (k, )
|
118 |
+
"""
|
119 |
+
gt_areas = bboxes_area(gt_bboxes)
|
120 |
+
# Rank all gt bbox areas. Smaller objects has larger priority
|
121 |
+
_, sort_idx = gt_areas.sort(descending=True)
|
122 |
+
sort_idx = sort_idx.argsort()
|
123 |
+
return sort_idx
|
124 |
+
|
125 |
+
def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
|
126 |
+
"""Assign gt to bboxes.
|
127 |
+
|
128 |
+
This method assigns gts to every bbox (proposal/anchor), each bbox \
|
129 |
+
will be assigned with -1, or a semi-positive number. -1 means \
|
130 |
+
negative sample, semi-positive number is the index (0-based) of \
|
131 |
+
assigned gt.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
|
135 |
+
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
|
136 |
+
gt_bboxes_ignore (tensor, optional): Ground truth bboxes that are
|
137 |
+
labelled as `ignored`, e.g., crowd boxes in COCO.
|
138 |
+
gt_labels (tensor, optional): Label of gt_bboxes, shape (num_gts,).
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
:obj:`AssignResult`: The assigned result. Note that \
|
142 |
+
shadowed_labels of shape (N, 2) is also added as an \
|
143 |
+
`assign_result` attribute. `shadowed_labels` is a tensor \
|
144 |
+
composed of N pairs of anchor_ind, class_label], where N \
|
145 |
+
is the number of anchors that lie in the outer region of a \
|
146 |
+
gt, anchor_ind is the shadowed anchor index and class_label \
|
147 |
+
is the shadowed class label.
|
148 |
+
|
149 |
+
Example:
|
150 |
+
>>> self = CenterRegionAssigner(0.2, 0.2)
|
151 |
+
>>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
|
152 |
+
>>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
|
153 |
+
>>> assign_result = self.assign(bboxes, gt_bboxes)
|
154 |
+
>>> expected_gt_inds = torch.LongTensor([1, 0])
|
155 |
+
>>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
|
156 |
+
"""
|
157 |
+
# There are in total 5 steps in the pixel assignment
|
158 |
+
# 1. Find core (the center region, say inner 0.2)
|
159 |
+
# and shadow (the relatively ourter part, say inner 0.2-0.5)
|
160 |
+
# regions of every gt.
|
161 |
+
# 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
|
162 |
+
# 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
|
163 |
+
# the image.
|
164 |
+
# 3.1. For overlapping objects, the prior bboxes in gt_core is
|
165 |
+
# assigned with the object with smallest area
|
166 |
+
# 4. Assign prior bboxes with class label according to its gt id.
|
167 |
+
# 4.1. Assign -1 to prior bboxes lying in shadowed gts
|
168 |
+
# 4.2. Assign positive prior boxes with the corresponding label
|
169 |
+
# 5. Find pixels lying in the shadow of an object and assign them with
|
170 |
+
# background label, but set the loss weight of its corresponding
|
171 |
+
# gt to zero.
|
172 |
+
assert bboxes.size(1) == 4, 'bboxes must have size of 4'
|
173 |
+
# 1. Find core positive and shadow region of every gt
|
174 |
+
gt_core = scale_boxes(gt_bboxes, self.pos_scale)
|
175 |
+
gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)
|
176 |
+
|
177 |
+
# 2. Find prior bboxes that lie in gt_core and gt_shadow regions
|
178 |
+
bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
|
179 |
+
# The center points lie within the gt boxes
|
180 |
+
is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
|
181 |
+
# Only calculate bbox and gt_core IoF. This enables small prior bboxes
|
182 |
+
# to match large gts
|
183 |
+
bbox_and_gt_core_overlaps = self.iou_calculator(
|
184 |
+
bboxes, gt_core, mode='iof')
|
185 |
+
# The center point of effective priors should be within the gt box
|
186 |
+
is_bbox_in_gt_core = is_bbox_in_gt & (
|
187 |
+
bbox_and_gt_core_overlaps > self.min_pos_iof) # shape (n, k)
|
188 |
+
|
189 |
+
is_bbox_in_gt_shadow = (
|
190 |
+
self.iou_calculator(bboxes, gt_shadow, mode='iof') >
|
191 |
+
self.min_pos_iof)
|
192 |
+
# Rule out center effective positive pixels
|
193 |
+
is_bbox_in_gt_shadow &= (~is_bbox_in_gt_core)
|
194 |
+
|
195 |
+
num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
|
196 |
+
if num_gts == 0 or num_bboxes == 0:
|
197 |
+
# If no gts exist, assign all pixels to negative
|
198 |
+
assigned_gt_ids = \
|
199 |
+
is_bbox_in_gt_core.new_zeros((num_bboxes,),
|
200 |
+
dtype=torch.long)
|
201 |
+
pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
|
202 |
+
else:
|
203 |
+
# Step 3: assign a one-hot gt id to each pixel, and smaller objects
|
204 |
+
# have high priority to assign the pixel.
|
205 |
+
sort_idx = self.get_gt_priorities(gt_bboxes)
|
206 |
+
assigned_gt_ids, pixels_in_gt_shadow = \
|
207 |
+
self.assign_one_hot_gt_indices(is_bbox_in_gt_core,
|
208 |
+
is_bbox_in_gt_shadow,
|
209 |
+
gt_priority=sort_idx)
|
210 |
+
|
211 |
+
if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
|
212 |
+
# No ground truth or boxes, return empty assignment
|
213 |
+
gt_bboxes_ignore = scale_boxes(
|
214 |
+
gt_bboxes_ignore, scale=self.ignore_gt_scale)
|
215 |
+
is_bbox_in_ignored_gts = is_located_in(bbox_centers,
|
216 |
+
gt_bboxes_ignore)
|
217 |
+
is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
|
218 |
+
assigned_gt_ids[is_bbox_in_ignored_gts] = -1
|
219 |
+
|
220 |
+
# 4. Assign prior bboxes with class label according to its gt id.
|
221 |
+
assigned_labels = None
|
222 |
+
shadowed_pixel_labels = None
|
223 |
+
if gt_labels is not None:
|
224 |
+
# Default assigned label is the background (-1)
|
225 |
+
assigned_labels = assigned_gt_ids.new_full((num_bboxes, ), -1)
|
226 |
+
pos_inds = torch.nonzero(
|
227 |
+
assigned_gt_ids > 0, as_tuple=False).squeeze()
|
228 |
+
if pos_inds.numel() > 0:
|
229 |
+
assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds]
|
230 |
+
- 1]
|
231 |
+
# 5. Find pixels lying in the shadow of an object
|
232 |
+
shadowed_pixel_labels = pixels_in_gt_shadow.clone()
|
233 |
+
if pixels_in_gt_shadow.numel() > 0:
|
234 |
+
pixel_idx, gt_idx =\
|
235 |
+
pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
|
236 |
+
assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
|
237 |
+
'Some pixels are dually assigned to ignore and gt!'
|
238 |
+
shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
|
239 |
+
override = (
|
240 |
+
assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
|
241 |
+
if self.foreground_dominate:
|
242 |
+
# When a pixel is both positive and shadowed, set it as pos
|
243 |
+
shadowed_pixel_labels = shadowed_pixel_labels[~override]
|
244 |
+
else:
|
245 |
+
# When a pixel is both pos and shadowed, set it as shadowed
|
246 |
+
assigned_labels[pixel_idx[override]] = -1
|
247 |
+
assigned_gt_ids[pixel_idx[override]] = 0
|
248 |
+
|
249 |
+
assign_result = AssignResult(
|
250 |
+
num_gts, assigned_gt_ids, None, labels=assigned_labels)
|
251 |
+
# Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
|
252 |
+
assign_result.set_extra_property('shadowed_labels',
|
253 |
+
shadowed_pixel_labels)
|
254 |
+
return assign_result
|
255 |
+
|
256 |
+
def assign_one_hot_gt_indices(self,
|
257 |
+
is_bbox_in_gt_core,
|
258 |
+
is_bbox_in_gt_shadow,
|
259 |
+
gt_priority=None):
|
260 |
+
"""Assign only one gt index to each prior box.
|
261 |
+
|
262 |
+
Gts with large gt_priority are more likely to be assigned.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
|
266 |
+
is in the core area of a gt (e.g. 0-0.2).
|
267 |
+
Shape: (num_prior, num_gt).
|
268 |
+
is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
|
269 |
+
center is in the shadowed area of a gt (e.g. 0.2-0.5).
|
270 |
+
Shape: (num_prior, num_gt).
|
271 |
+
gt_priority (Tensor): Priorities of gts. The gt with a higher
|
272 |
+
priority is more likely to be assigned to the bbox when the bbox
|
273 |
+
match with multiple gts. Shape: (num_gt, ).
|
274 |
+
|
275 |
+
Returns:
|
276 |
+
tuple: Returns (assigned_gt_inds, shadowed_gt_inds).
|
277 |
+
|
278 |
+
- assigned_gt_inds: The assigned gt index of each prior bbox \
|
279 |
+
(i.e. index from 1 to num_gts). Shape: (num_prior, ).
|
280 |
+
- shadowed_gt_inds: shadowed gt indices. It is a tensor of \
|
281 |
+
shape (num_ignore, 2) with first column being the \
|
282 |
+
shadowed prior bbox indices and the second column the \
|
283 |
+
shadowed gt indices (1-based).
|
284 |
+
"""
|
285 |
+
num_bboxes, num_gts = is_bbox_in_gt_core.shape
|
286 |
+
|
287 |
+
if gt_priority is None:
|
288 |
+
gt_priority = torch.arange(
|
289 |
+
num_gts, device=is_bbox_in_gt_core.device)
|
290 |
+
assert gt_priority.size(0) == num_gts
|
291 |
+
# The bigger gt_priority, the more preferable to be assigned
|
292 |
+
# The assigned inds are by default 0 (background)
|
293 |
+
assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
|
294 |
+
dtype=torch.long)
|
295 |
+
# Shadowed bboxes are assigned to be background. But the corresponding
|
296 |
+
# label is ignored during loss calculation, which is done through
|
297 |
+
# shadowed_gt_inds
|
298 |
+
shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
|
299 |
+
if is_bbox_in_gt_core.sum() == 0: # No gt match
|
300 |
+
shadowed_gt_inds[:, 1] += 1 # 1-based. For consistency issue
|
301 |
+
return assigned_gt_inds, shadowed_gt_inds
|
302 |
+
|
303 |
+
# The priority of each prior box and gt pair. If one prior box is
|
304 |
+
# matched bo multiple gts. Only the pair with the highest priority
|
305 |
+
# is saved
|
306 |
+
pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
|
307 |
+
-1,
|
308 |
+
dtype=torch.long)
|
309 |
+
|
310 |
+
# Each bbox could match with multiple gts.
|
311 |
+
# The following codes deal with this situation
|
312 |
+
# Matched bboxes (to any gt). Shape: (num_pos_anchor, )
|
313 |
+
inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
|
314 |
+
# The matched gt index of each positive bbox. Length >= num_pos_anchor
|
315 |
+
# , since one bbox could match multiple gts
|
316 |
+
matched_bbox_gt_inds = torch.nonzero(
|
317 |
+
is_bbox_in_gt_core, as_tuple=False)[:, 1]
|
318 |
+
# Assign priority to each bbox-gt pair.
|
319 |
+
pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
|
320 |
+
_, argmax_priority = pair_priority[inds_of_match].max(dim=1)
|
321 |
+
assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based
|
322 |
+
# Zero-out the assigned anchor box to filter the shadowed gt indices
|
323 |
+
is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
|
324 |
+
# Concat the shadowed indices due to overlapping with that out side of
|
325 |
+
# effective scale. shape: (total_num_ignore, 2)
|
326 |
+
shadowed_gt_inds = torch.cat(
|
327 |
+
(shadowed_gt_inds, torch.nonzero(
|
328 |
+
is_bbox_in_gt_core, as_tuple=False)),
|
329 |
+
dim=0)
|
330 |
+
# `is_bbox_in_gt_core` should be changed back to keep arguments intact.
|
331 |
+
is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
|
332 |
+
# 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
|
333 |
+
if shadowed_gt_inds.numel() > 0:
|
334 |
+
shadowed_gt_inds[:, 1] += 1
|
335 |
+
return assigned_gt_inds, shadowed_gt_inds
|
mmdet/core/bbox/assigners/grid_assigner.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..builder import BBOX_ASSIGNERS
|
4 |
+
from ..iou_calculators import build_iou_calculator
|
5 |
+
from .assign_result import AssignResult
|
6 |
+
from .base_assigner import BaseAssigner
|
7 |
+
|
8 |
+
|
9 |
+
@BBOX_ASSIGNERS.register_module()
|
10 |
+
class GridAssigner(BaseAssigner):
|
11 |
+
"""Assign a corresponding gt bbox or background to each bbox.
|
12 |
+
|
13 |
+
Each proposals will be assigned with `-1`, `0`, or a positive integer
|
14 |
+
indicating the ground truth index.
|
15 |
+
|
16 |
+
- -1: don't care
|
17 |
+
- 0: negative sample, no assigned gt
|
18 |
+
- positive integer: positive sample, index (1-based) of assigned gt
|
19 |
+
|
20 |
+
Args:
|
21 |
+
pos_iou_thr (float): IoU threshold for positive bboxes.
|
22 |
+
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
|
23 |
+
min_pos_iou (float): Minimum iou for a bbox to be considered as a
|
24 |
+
positive bbox. Positive samples can have smaller IoU than
|
25 |
+
pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
|
26 |
+
gt_max_assign_all (bool): Whether to assign all bboxes with the same
|
27 |
+
highest overlap with some gt to that gt.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self,
|
31 |
+
pos_iou_thr,
|
32 |
+
neg_iou_thr,
|
33 |
+
min_pos_iou=.0,
|
34 |
+
gt_max_assign_all=True,
|
35 |
+
iou_calculator=dict(type='BboxOverlaps2D')):
|
36 |
+
self.pos_iou_thr = pos_iou_thr
|
37 |
+
self.neg_iou_thr = neg_iou_thr
|
38 |
+
self.min_pos_iou = min_pos_iou
|
39 |
+
self.gt_max_assign_all = gt_max_assign_all
|
40 |
+
self.iou_calculator = build_iou_calculator(iou_calculator)
|
41 |
+
|
42 |
+
def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None):
|
43 |
+
"""Assign gt to bboxes. The process is very much like the max iou
|
44 |
+
assigner, except that positive samples are constrained within the cell
|
45 |
+
that the gt boxes fell in.
|
46 |
+
|
47 |
+
This method assign a gt bbox to every bbox (proposal/anchor), each bbox
|
48 |
+
will be assigned with -1, 0, or a positive number. -1 means don't care,
|
49 |
+
0 means negative sample, positive number is the index (1-based) of
|
50 |
+
assigned gt.
|
51 |
+
The assignment is done in following steps, the order matters.
|
52 |
+
|
53 |
+
1. assign every bbox to -1
|
54 |
+
2. assign proposals whose iou with all gts <= neg_iou_thr to 0
|
55 |
+
3. for each bbox within a cell, if the iou with its nearest gt >
|
56 |
+
pos_iou_thr and the center of that gt falls inside the cell,
|
57 |
+
assign it to that bbox
|
58 |
+
4. for each gt bbox, assign its nearest proposals within the cell the
|
59 |
+
gt bbox falls in to itself.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
|
63 |
+
box_responsible_flags (Tensor): flag to indicate whether box is
|
64 |
+
responsible for prediction, shape(n, )
|
65 |
+
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
|
66 |
+
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
:obj:`AssignResult`: The assign result.
|
70 |
+
"""
|
71 |
+
num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
|
72 |
+
|
73 |
+
# compute iou between all gt and bboxes
|
74 |
+
overlaps = self.iou_calculator(gt_bboxes, bboxes)
|
75 |
+
|
76 |
+
# 1. assign -1 by default
|
77 |
+
assigned_gt_inds = overlaps.new_full((num_bboxes, ),
|
78 |
+
-1,
|
79 |
+
dtype=torch.long)
|
80 |
+
|
81 |
+
if num_gts == 0 or num_bboxes == 0:
|
82 |
+
# No ground truth or boxes, return empty assignment
|
83 |
+
max_overlaps = overlaps.new_zeros((num_bboxes, ))
|
84 |
+
if num_gts == 0:
|
85 |
+
# No truth, assign everything to background
|
86 |
+
assigned_gt_inds[:] = 0
|
87 |
+
if gt_labels is None:
|
88 |
+
assigned_labels = None
|
89 |
+
else:
|
90 |
+
assigned_labels = overlaps.new_full((num_bboxes, ),
|
91 |
+
-1,
|
92 |
+
dtype=torch.long)
|
93 |
+
return AssignResult(
|
94 |
+
num_gts,
|
95 |
+
assigned_gt_inds,
|
96 |
+
max_overlaps,
|
97 |
+
labels=assigned_labels)
|
98 |
+
|
99 |
+
# 2. assign negative: below
|
100 |
+
# for each anchor, which gt best overlaps with it
|
101 |
+
# for each anchor, the max iou of all gts
|
102 |
+
# shape of max_overlaps == argmax_overlaps == num_bboxes
|
103 |
+
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
|
104 |
+
|
105 |
+
if isinstance(self.neg_iou_thr, float):
|
106 |
+
assigned_gt_inds[(max_overlaps >= 0)
|
107 |
+
& (max_overlaps <= self.neg_iou_thr)] = 0
|
108 |
+
elif isinstance(self.neg_iou_thr, (tuple, list)):
|
109 |
+
assert len(self.neg_iou_thr) == 2
|
110 |
+
assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
|
111 |
+
& (max_overlaps <= self.neg_iou_thr[1])] = 0
|
112 |
+
|
113 |
+
# 3. assign positive: falls into responsible cell and above
|
114 |
+
# positive IOU threshold, the order matters.
|
115 |
+
# the prior condition of comparision is to filter out all
|
116 |
+
# unrelated anchors, i.e. not box_responsible_flags
|
117 |
+
overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1.
|
118 |
+
|
119 |
+
# calculate max_overlaps again, but this time we only consider IOUs
|
120 |
+
# for anchors responsible for prediction
|
121 |
+
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
|
122 |
+
|
123 |
+
# for each gt, which anchor best overlaps with it
|
124 |
+
# for each gt, the max iou of all proposals
|
125 |
+
# shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
|
126 |
+
gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
|
127 |
+
|
128 |
+
pos_inds = (max_overlaps >
|
129 |
+
self.pos_iou_thr) & box_responsible_flags.type(torch.bool)
|
130 |
+
assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
|
131 |
+
|
132 |
+
# 4. assign positive to max overlapped anchors within responsible cell
|
133 |
+
for i in range(num_gts):
|
134 |
+
if gt_max_overlaps[i] > self.min_pos_iou:
|
135 |
+
if self.gt_max_assign_all:
|
136 |
+
max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
|
137 |
+
box_responsible_flags.type(torch.bool)
|
138 |
+
assigned_gt_inds[max_iou_inds] = i + 1
|
139 |
+
elif box_responsible_flags[gt_argmax_overlaps[i]]:
|
140 |
+
assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
|
141 |
+
|
142 |
+
# assign labels of positive anchors
|
143 |
+
if gt_labels is not None:
|
144 |
+
assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
|
145 |
+
pos_inds = torch.nonzero(
|
146 |
+
assigned_gt_inds > 0, as_tuple=False).squeeze()
|
147 |
+
if pos_inds.numel() > 0:
|
148 |
+
assigned_labels[pos_inds] = gt_labels[
|
149 |
+
assigned_gt_inds[pos_inds] - 1]
|
150 |
+
|
151 |
+
else:
|
152 |
+
assigned_labels = None
|
153 |
+
|
154 |
+
return AssignResult(
|
155 |
+
num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
|
mmdet/core/bbox/assigners/hungarian_assigner.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..builder import BBOX_ASSIGNERS
|
4 |
+
from ..match_costs import build_match_cost
|
5 |
+
from ..transforms import bbox_cxcywh_to_xyxy
|
6 |
+
from .assign_result import AssignResult
|
7 |
+
from .base_assigner import BaseAssigner
|
8 |
+
|
9 |
+
try:
|
10 |
+
from scipy.optimize import linear_sum_assignment
|
11 |
+
except ImportError:
|
12 |
+
linear_sum_assignment = None
|
13 |
+
|
14 |
+
|
15 |
+
@BBOX_ASSIGNERS.register_module()
|
16 |
+
class HungarianAssigner(BaseAssigner):
|
17 |
+
"""Computes one-to-one matching between predictions and ground truth.
|
18 |
+
|
19 |
+
This class computes an assignment between the targets and the predictions
|
20 |
+
based on the costs. The costs are weighted sum of three components:
|
21 |
+
classification cost, regression L1 cost and regression iou cost. The
|
22 |
+
targets don't include the no_object, so generally there are more
|
23 |
+
predictions than targets. After the one-to-one matching, the un-matched
|
24 |
+
are treated as backgrounds. Thus each query prediction will be assigned
|
25 |
+
with `0` or a positive integer indicating the ground truth index:
|
26 |
+
|
27 |
+
- 0: negative sample, no assigned gt
|
28 |
+
- positive integer: positive sample, index (1-based) of assigned gt
|
29 |
+
|
30 |
+
Args:
|
31 |
+
cls_weight (int | float, optional): The scale factor for classification
|
32 |
+
cost. Default 1.0.
|
33 |
+
bbox_weight (int | float, optional): The scale factor for regression
|
34 |
+
L1 cost. Default 1.0.
|
35 |
+
iou_weight (int | float, optional): The scale factor for regression
|
36 |
+
iou cost. Default 1.0.
|
37 |
+
iou_calculator (dict | optional): The config for the iou calculation.
|
38 |
+
Default type `BboxOverlaps2D`.
|
39 |
+
iou_mode (str | optional): "iou" (intersection over union), "iof"
|
40 |
+
(intersection over foreground), or "giou" (generalized
|
41 |
+
intersection over union). Default "giou".
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self,
|
45 |
+
cls_cost=dict(type='ClassificationCost', weight=1.),
|
46 |
+
reg_cost=dict(type='BBoxL1Cost', weight=1.0),
|
47 |
+
iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
|
48 |
+
self.cls_cost = build_match_cost(cls_cost)
|
49 |
+
self.reg_cost = build_match_cost(reg_cost)
|
50 |
+
self.iou_cost = build_match_cost(iou_cost)
|
51 |
+
|
52 |
+
def assign(self,
|
53 |
+
bbox_pred,
|
54 |
+
cls_pred,
|
55 |
+
gt_bboxes,
|
56 |
+
gt_labels,
|
57 |
+
img_meta,
|
58 |
+
gt_bboxes_ignore=None,
|
59 |
+
eps=1e-7):
|
60 |
+
"""Computes one-to-one matching based on the weighted costs.
|
61 |
+
|
62 |
+
This method assign each query prediction to a ground truth or
|
63 |
+
background. The `assigned_gt_inds` with -1 means don't care,
|
64 |
+
0 means negative sample, and positive number is the index (1-based)
|
65 |
+
of assigned gt.
|
66 |
+
The assignment is done in the following steps, the order matters.
|
67 |
+
|
68 |
+
1. assign every prediction to -1
|
69 |
+
2. compute the weighted costs
|
70 |
+
3. do Hungarian matching on CPU based on the costs
|
71 |
+
4. assign all to 0 (background) first, then for each matched pair
|
72 |
+
between predictions and gts, treat this prediction as foreground
|
73 |
+
and assign the corresponding gt index (plus 1) to it.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
bbox_pred (Tensor): Predicted boxes with normalized coordinates
|
77 |
+
(cx, cy, w, h), which are all in range [0, 1]. Shape
|
78 |
+
[num_query, 4].
|
79 |
+
cls_pred (Tensor): Predicted classification logits, shape
|
80 |
+
[num_query, num_class].
|
81 |
+
gt_bboxes (Tensor): Ground truth boxes with unnormalized
|
82 |
+
coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
|
83 |
+
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
|
84 |
+
img_meta (dict): Meta information for current image.
|
85 |
+
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
|
86 |
+
labelled as `ignored`. Default None.
|
87 |
+
eps (int | float, optional): A value added to the denominator for
|
88 |
+
numerical stability. Default 1e-7.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
:obj:`AssignResult`: The assigned result.
|
92 |
+
"""
|
93 |
+
assert gt_bboxes_ignore is None, \
|
94 |
+
'Only case when gt_bboxes_ignore is None is supported.'
|
95 |
+
num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
|
96 |
+
|
97 |
+
# 1. assign -1 by default
|
98 |
+
assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
|
99 |
+
-1,
|
100 |
+
dtype=torch.long)
|
101 |
+
assigned_labels = bbox_pred.new_full((num_bboxes, ),
|
102 |
+
-1,
|
103 |
+
dtype=torch.long)
|
104 |
+
if num_gts == 0 or num_bboxes == 0:
|
105 |
+
# No ground truth or boxes, return empty assignment
|
106 |
+
if num_gts == 0:
|
107 |
+
# No ground truth, assign all to background
|
108 |
+
assigned_gt_inds[:] = 0
|
109 |
+
return AssignResult(
|
110 |
+
num_gts, assigned_gt_inds, None, labels=assigned_labels)
|
111 |
+
img_h, img_w, _ = img_meta['img_shape']
|
112 |
+
factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
|
113 |
+
img_h]).unsqueeze(0)
|
114 |
+
|
115 |
+
# 2. compute the weighted costs
|
116 |
+
# classification and bboxcost.
|
117 |
+
cls_cost = self.cls_cost(cls_pred, gt_labels)
|
118 |
+
# regression L1 cost
|
119 |
+
normalize_gt_bboxes = gt_bboxes / factor
|
120 |
+
reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
|
121 |
+
# regression iou cost, defaultly giou is used in official DETR.
|
122 |
+
bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
|
123 |
+
iou_cost = self.iou_cost(bboxes, gt_bboxes)
|
124 |
+
# weighted sum of above three costs
|
125 |
+
cost = cls_cost + reg_cost + iou_cost
|
126 |
+
|
127 |
+
# 3. do Hungarian matching on CPU using linear_sum_assignment
|
128 |
+
cost = cost.detach().cpu()
|
129 |
+
if linear_sum_assignment is None:
|
130 |
+
raise ImportError('Please run "pip install scipy" '
|
131 |
+
'to install scipy first.')
|
132 |
+
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
|
133 |
+
matched_row_inds = torch.from_numpy(matched_row_inds).to(
|
134 |
+
bbox_pred.device)
|
135 |
+
matched_col_inds = torch.from_numpy(matched_col_inds).to(
|
136 |
+
bbox_pred.device)
|
137 |
+
|
138 |
+
# 4. assign backgrounds and foregrounds
|
139 |
+
# assign all indices to backgrounds first
|
140 |
+
assigned_gt_inds[:] = 0
|
141 |
+
# assign foregrounds based on matching results
|
142 |
+
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
|
143 |
+
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
|
144 |
+
return AssignResult(
|
145 |
+
num_gts, assigned_gt_inds, None, labels=assigned_labels)
|
mmdet/core/bbox/assigners/max_iou_assigner.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..builder import BBOX_ASSIGNERS
|
4 |
+
from ..iou_calculators import build_iou_calculator
|
5 |
+
from .assign_result import AssignResult
|
6 |
+
from .base_assigner import BaseAssigner
|
7 |
+
|
8 |
+
|
9 |
+
@BBOX_ASSIGNERS.register_module()
|
10 |
+
class MaxIoUAssigner(BaseAssigner):
|
11 |
+
"""Assign a corresponding gt bbox or background to each bbox.
|
12 |
+
|
13 |
+
Each proposals will be assigned with `-1`, or a semi-positive integer
|
14 |
+
indicating the ground truth index.
|
15 |
+
|
16 |
+
- -1: negative sample, no assigned gt
|
17 |
+
- semi-positive integer: positive sample, index (0-based) of assigned gt
|
18 |
+
|
19 |
+
Args:
|
20 |
+
pos_iou_thr (float): IoU threshold for positive bboxes.
|
21 |
+
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
|
22 |
+
min_pos_iou (float): Minimum iou for a bbox to be considered as a
|
23 |
+
positive bbox. Positive samples can have smaller IoU than
|
24 |
+
pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
|
25 |
+
gt_max_assign_all (bool): Whether to assign all bboxes with the same
|
26 |
+
highest overlap with some gt to that gt.
|
27 |
+
ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
|
28 |
+
`gt_bboxes_ignore` is specified). Negative values mean not
|
29 |
+
ignoring any bboxes.
|
30 |
+
ignore_wrt_candidates (bool): Whether to compute the iof between
|
31 |
+
`bboxes` and `gt_bboxes_ignore`, or the contrary.
|
32 |
+
match_low_quality (bool): Whether to allow low quality matches. This is
|
33 |
+
usually allowed for RPN and single stage detectors, but not allowed
|
34 |
+
in the second stage. Details are demonstrated in Step 4.
|
35 |
+
gpu_assign_thr (int): The upper bound of the number of GT for GPU
|
36 |
+
assign. When the number of gt is above this threshold, will assign
|
37 |
+
on CPU device. Negative values mean not assign on CPU.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self,
|
41 |
+
pos_iou_thr,
|
42 |
+
neg_iou_thr,
|
43 |
+
min_pos_iou=.0,
|
44 |
+
gt_max_assign_all=True,
|
45 |
+
ignore_iof_thr=-1,
|
46 |
+
ignore_wrt_candidates=True,
|
47 |
+
match_low_quality=True,
|
48 |
+
gpu_assign_thr=-1,
|
49 |
+
iou_calculator=dict(type='BboxOverlaps2D')):
|
50 |
+
self.pos_iou_thr = pos_iou_thr
|
51 |
+
self.neg_iou_thr = neg_iou_thr
|
52 |
+
self.min_pos_iou = min_pos_iou
|
53 |
+
self.gt_max_assign_all = gt_max_assign_all
|
54 |
+
self.ignore_iof_thr = ignore_iof_thr
|
55 |
+
self.ignore_wrt_candidates = ignore_wrt_candidates
|
56 |
+
self.gpu_assign_thr = gpu_assign_thr
|
57 |
+
self.match_low_quality = match_low_quality
|
58 |
+
self.iou_calculator = build_iou_calculator(iou_calculator)
|
59 |
+
|
60 |
+
def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
|
61 |
+
"""Assign gt to bboxes.
|
62 |
+
|
63 |
+
This method assign a gt bbox to every bbox (proposal/anchor), each bbox
|
64 |
+
will be assigned with -1, or a semi-positive number. -1 means negative
|
65 |
+
sample, semi-positive number is the index (0-based) of assigned gt.
|
66 |
+
The assignment is done in following steps, the order matters.
|
67 |
+
|
68 |
+
1. assign every bbox to the background
|
69 |
+
2. assign proposals whose iou with all gts < neg_iou_thr to 0
|
70 |
+
3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
|
71 |
+
assign it to that bbox
|
72 |
+
4. for each gt bbox, assign its nearest proposals (may be more than
|
73 |
+
one) to itself
|
74 |
+
|
75 |
+
Args:
|
76 |
+
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
|
77 |
+
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
|
78 |
+
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
|
79 |
+
labelled as `ignored`, e.g., crowd boxes in COCO.
|
80 |
+
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
:obj:`AssignResult`: The assign result.
|
84 |
+
|
85 |
+
Example:
|
86 |
+
>>> self = MaxIoUAssigner(0.5, 0.5)
|
87 |
+
>>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
|
88 |
+
>>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
|
89 |
+
>>> assign_result = self.assign(bboxes, gt_bboxes)
|
90 |
+
>>> expected_gt_inds = torch.LongTensor([1, 0])
|
91 |
+
>>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
|
92 |
+
"""
|
93 |
+
assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
|
94 |
+
gt_bboxes.shape[0] > self.gpu_assign_thr) else False
|
95 |
+
# compute overlap and assign gt on CPU when number of GT is large
|
96 |
+
if assign_on_cpu:
|
97 |
+
device = bboxes.device
|
98 |
+
bboxes = bboxes.cpu()
|
99 |
+
gt_bboxes = gt_bboxes.cpu()
|
100 |
+
if gt_bboxes_ignore is not None:
|
101 |
+
gt_bboxes_ignore = gt_bboxes_ignore.cpu()
|
102 |
+
if gt_labels is not None:
|
103 |
+
gt_labels = gt_labels.cpu()
|
104 |
+
|
105 |
+
overlaps = self.iou_calculator(gt_bboxes, bboxes)
|
106 |
+
|
107 |
+
if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
|
108 |
+
and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
|
109 |
+
if self.ignore_wrt_candidates:
|
110 |
+
ignore_overlaps = self.iou_calculator(
|
111 |
+
bboxes, gt_bboxes_ignore, mode='iof')
|
112 |
+
ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
|
113 |
+
else:
|
114 |
+
ignore_overlaps = self.iou_calculator(
|
115 |
+
gt_bboxes_ignore, bboxes, mode='iof')
|
116 |
+
ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
|
117 |
+
overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
|
118 |
+
|
119 |
+
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
|
120 |
+
if assign_on_cpu:
|
121 |
+
assign_result.gt_inds = assign_result.gt_inds.to(device)
|
122 |
+
assign_result.max_overlaps = assign_result.max_overlaps.to(device)
|
123 |
+
if assign_result.labels is not None:
|
124 |
+
assign_result.labels = assign_result.labels.to(device)
|
125 |
+
return assign_result
|
126 |
+
|
127 |
+
def assign_wrt_overlaps(self, overlaps, gt_labels=None):
|
128 |
+
"""Assign w.r.t. the overlaps of bboxes with gts.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
|
132 |
+
shape(k, n).
|
133 |
+
gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
:obj:`AssignResult`: The assign result.
|
137 |
+
"""
|
138 |
+
num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
|
139 |
+
|
140 |
+
# 1. assign -1 by default
|
141 |
+
assigned_gt_inds = overlaps.new_full((num_bboxes, ),
|
142 |
+
-1,
|
143 |
+
dtype=torch.long)
|
144 |
+
|
145 |
+
if num_gts == 0 or num_bboxes == 0:
|
146 |
+
# No ground truth or boxes, return empty assignment
|
147 |
+
max_overlaps = overlaps.new_zeros((num_bboxes, ))
|
148 |
+
if num_gts == 0:
|
149 |
+
# No truth, assign everything to background
|
150 |
+
assigned_gt_inds[:] = 0
|
151 |
+
if gt_labels is None:
|
152 |
+
assigned_labels = None
|
153 |
+
else:
|
154 |
+
assigned_labels = overlaps.new_full((num_bboxes, ),
|
155 |
+
-1,
|
156 |
+
dtype=torch.long)
|
157 |
+
return AssignResult(
|
158 |
+
num_gts,
|
159 |
+
assigned_gt_inds,
|
160 |
+
max_overlaps,
|
161 |
+
labels=assigned_labels)
|
162 |
+
|
163 |
+
# for each anchor, which gt best overlaps with it
|
164 |
+
# for each anchor, the max iou of all gts
|
165 |
+
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
|
166 |
+
# for each gt, which anchor best overlaps with it
|
167 |
+
# for each gt, the max iou of all proposals
|
168 |
+
gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
|
169 |
+
|
170 |
+
# 2. assign negative: below
|
171 |
+
# the negative inds are set to be 0
|
172 |
+
if isinstance(self.neg_iou_thr, float):
|
173 |
+
assigned_gt_inds[(max_overlaps >= 0)
|
174 |
+
& (max_overlaps < self.neg_iou_thr)] = 0
|
175 |
+
elif isinstance(self.neg_iou_thr, tuple):
|
176 |
+
assert len(self.neg_iou_thr) == 2
|
177 |
+
assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
|
178 |
+
& (max_overlaps < self.neg_iou_thr[1])] = 0
|
179 |
+
|
180 |
+
# 3. assign positive: above positive IoU threshold
|
181 |
+
pos_inds = max_overlaps >= self.pos_iou_thr
|
182 |
+
assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
|
183 |
+
|
184 |
+
if self.match_low_quality:
|
185 |
+
# Low-quality matching will overwrite the assigned_gt_inds assigned
|
186 |
+
# in Step 3. Thus, the assigned gt might not be the best one for
|
187 |
+
# prediction.
|
188 |
+
# For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
|
189 |
+
# bbox 1 will be assigned as the best target for bbox A in step 3.
|
190 |
+
# However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
|
191 |
+
# assigned_gt_inds will be overwritten to be bbox B.
|
192 |
+
# This might be the reason that it is not used in ROI Heads.
|
193 |
+
for i in range(num_gts):
|
194 |
+
if gt_max_overlaps[i] >= self.min_pos_iou:
|
195 |
+
if self.gt_max_assign_all:
|
196 |
+
max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
|
197 |
+
assigned_gt_inds[max_iou_inds] = i + 1
|
198 |
+
else:
|
199 |
+
assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
|
200 |
+
|
201 |
+
if gt_labels is not None:
|
202 |
+
assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
|
203 |
+
pos_inds = torch.nonzero(
|
204 |
+
assigned_gt_inds > 0, as_tuple=False).squeeze()
|
205 |
+
if pos_inds.numel() > 0:
|
206 |
+
assigned_labels[pos_inds] = gt_labels[
|
207 |
+
assigned_gt_inds[pos_inds] - 1]
|
208 |
+
else:
|
209 |
+
assigned_labels = None
|
210 |
+
|
211 |
+
return AssignResult(
|
212 |
+
num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
|
mmdet/core/bbox/assigners/point_assigner.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ..builder import BBOX_ASSIGNERS
|
4 |
+
from .assign_result import AssignResult
|
5 |
+
from .base_assigner import BaseAssigner
|
6 |
+
|
7 |
+
|
8 |
+
@BBOX_ASSIGNERS.register_module()
|
9 |
+
class PointAssigner(BaseAssigner):
|
10 |
+
"""Assign a corresponding gt bbox or background to each point.
|
11 |
+
|
12 |
+
Each proposals will be assigned with `0`, or a positive integer
|
13 |
+
indicating the ground truth index.
|
14 |
+
|
15 |
+
- 0: negative sample, no assigned gt
|
16 |
+
- positive integer: positive sample, index (1-based) of assigned gt
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, scale=4, pos_num=3):
|
20 |
+
self.scale = scale
|
21 |
+
self.pos_num = pos_num
|
22 |
+
|
23 |
+
def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
|
24 |
+
"""Assign gt to points.
|
25 |
+
|
26 |
+
This method assign a gt bbox to every points set, each points set
|
27 |
+
will be assigned with the background_label (-1), or a label number.
|
28 |
+
-1 is background, and semi-positive number is the index (0-based) of
|
29 |
+
assigned gt.
|
30 |
+
The assignment is done in following steps, the order matters.
|
31 |
+
|
32 |
+
1. assign every points to the background_label (-1)
|
33 |
+
2. A point is assigned to some gt bbox if
|
34 |
+
(i) the point is within the k closest points to the gt bbox
|
35 |
+
(ii) the distance between this point and the gt is smaller than
|
36 |
+
other gt bboxes
|
37 |
+
|
38 |
+
Args:
|
39 |
+
points (Tensor): points to be assigned, shape(n, 3) while last
|
40 |
+
dimension stands for (x, y, stride).
|
41 |
+
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
|
42 |
+
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
|
43 |
+
labelled as `ignored`, e.g., crowd boxes in COCO.
|
44 |
+
NOTE: currently unused.
|
45 |
+
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
:obj:`AssignResult`: The assign result.
|
49 |
+
"""
|
50 |
+
num_points = points.shape[0]
|
51 |
+
num_gts = gt_bboxes.shape[0]
|
52 |
+
|
53 |
+
if num_gts == 0 or num_points == 0:
|
54 |
+
# If no truth assign everything to the background
|
55 |
+
assigned_gt_inds = points.new_full((num_points, ),
|
56 |
+
0,
|
57 |
+
dtype=torch.long)
|
58 |
+
if gt_labels is None:
|
59 |
+
assigned_labels = None
|
60 |
+
else:
|
61 |
+
assigned_labels = points.new_full((num_points, ),
|
62 |
+
-1,
|
63 |
+
dtype=torch.long)
|
64 |
+
return AssignResult(
|
65 |
+
num_gts, assigned_gt_inds, None, labels=assigned_labels)
|
66 |
+
|
67 |
+
points_xy = points[:, :2]
|
68 |
+
points_stride = points[:, 2]
|
69 |
+
points_lvl = torch.log2(
|
70 |
+
points_stride).int() # [3...,4...,5...,6...,7...]
|
71 |
+
lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
|
72 |
+
|
73 |
+
# assign gt box
|
74 |
+
gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
|
75 |
+
gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
|
76 |
+
scale = self.scale
|
77 |
+
gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
|
78 |
+
torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
|
79 |
+
gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)
|
80 |
+
|
81 |
+
# stores the assigned gt index of each point
|
82 |
+
assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
|
83 |
+
# stores the assigned gt dist (to this point) of each point
|
84 |
+
assigned_gt_dist = points.new_full((num_points, ), float('inf'))
|
85 |
+
points_range = torch.arange(points.shape[0])
|
86 |
+
|
87 |
+
for idx in range(num_gts):
|
88 |
+
gt_lvl = gt_bboxes_lvl[idx]
|
89 |
+
# get the index of points in this level
|
90 |
+
lvl_idx = gt_lvl == points_lvl
|
91 |
+
points_index = points_range[lvl_idx]
|
92 |
+
# get the points in this level
|
93 |
+
lvl_points = points_xy[lvl_idx, :]
|
94 |
+
# get the center point of gt
|
95 |
+
gt_point = gt_bboxes_xy[[idx], :]
|
96 |
+
# get width and height of gt
|
97 |
+
gt_wh = gt_bboxes_wh[[idx], :]
|
98 |
+
# compute the distance between gt center and
|
99 |
+
# all points in this level
|
100 |
+
points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
|
101 |
+
# find the nearest k points to gt center in this level
|
102 |
+
min_dist, min_dist_index = torch.topk(
|
103 |
+
points_gt_dist, self.pos_num, largest=False)
|
104 |
+
# the index of nearest k points to gt center in this level
|
105 |
+
min_dist_points_index = points_index[min_dist_index]
|
106 |
+
# The less_than_recorded_index stores the index
|
107 |
+
# of min_dist that is less then the assigned_gt_dist. Where
|
108 |
+
# assigned_gt_dist stores the dist from previous assigned gt
|
109 |
+
# (if exist) to each point.
|
110 |
+
less_than_recorded_index = min_dist < assigned_gt_dist[
|
111 |
+
min_dist_points_index]
|
112 |
+
# The min_dist_points_index stores the index of points satisfy:
|
113 |
+
# (1) it is k nearest to current gt center in this level.
|
114 |
+
# (2) it is closer to current gt center than other gt center.
|
115 |
+
min_dist_points_index = min_dist_points_index[
|
116 |
+
less_than_recorded_index]
|
117 |
+
# assign the result
|
118 |
+
assigned_gt_inds[min_dist_points_index] = idx + 1
|
119 |
+
assigned_gt_dist[min_dist_points_index] = min_dist[
|
120 |
+
less_than_recorded_index]
|
121 |
+
|
122 |
+
if gt_labels is not None:
|
123 |
+
assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
|
124 |
+
pos_inds = torch.nonzero(
|
125 |
+
assigned_gt_inds > 0, as_tuple=False).squeeze()
|
126 |
+
if pos_inds.numel() > 0:
|
127 |
+
assigned_labels[pos_inds] = gt_labels[
|
128 |
+
assigned_gt_inds[pos_inds] - 1]
|
129 |
+
else:
|
130 |
+
assigned_labels = None
|
131 |
+
|
132 |
+
return AssignResult(
|
133 |
+
num_gts, assigned_gt_inds, None, labels=assigned_labels)
|
mmdet/core/bbox/assigners/region_assigner.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from mmdet.core import anchor_inside_flags
|
4 |
+
from ..builder import BBOX_ASSIGNERS
|
5 |
+
from .assign_result import AssignResult
|
6 |
+
from .base_assigner import BaseAssigner
|
7 |
+
|
8 |
+
|
9 |
+
def calc_region(bbox, ratio, stride, featmap_size=None):
|
10 |
+
"""Calculate region of the box defined by the ratio, the ratio is from the
|
11 |
+
center of the box to every edge."""
|
12 |
+
# project bbox on the feature
|
13 |
+
f_bbox = bbox / stride
|
14 |
+
x1 = torch.round((1 - ratio) * f_bbox[0] + ratio * f_bbox[2])
|
15 |
+
y1 = torch.round((1 - ratio) * f_bbox[1] + ratio * f_bbox[3])
|
16 |
+
x2 = torch.round(ratio * f_bbox[0] + (1 - ratio) * f_bbox[2])
|
17 |
+
y2 = torch.round(ratio * f_bbox[1] + (1 - ratio) * f_bbox[3])
|
18 |
+
if featmap_size is not None:
|
19 |
+
x1 = x1.clamp(min=0, max=featmap_size[1])
|
20 |
+
y1 = y1.clamp(min=0, max=featmap_size[0])
|
21 |
+
x2 = x2.clamp(min=0, max=featmap_size[1])
|
22 |
+
y2 = y2.clamp(min=0, max=featmap_size[0])
|
23 |
+
return (x1, y1, x2, y2)
|
24 |
+
|
25 |
+
|
26 |
+
def anchor_ctr_inside_region_flags(anchors, stride, region):
|
27 |
+
"""Get the flag indicate whether anchor centers are inside regions."""
|
28 |
+
x1, y1, x2, y2 = region
|
29 |
+
f_anchors = anchors / stride
|
30 |
+
x = (f_anchors[:, 0] + f_anchors[:, 2]) * 0.5
|
31 |
+
y = (f_anchors[:, 1] + f_anchors[:, 3]) * 0.5
|
32 |
+
flags = (x >= x1) & (x <= x2) & (y >= y1) & (y <= y2)
|
33 |
+
return flags
|
34 |
+
|
35 |
+
|
36 |
+
@BBOX_ASSIGNERS.register_module()
|
37 |
+
class RegionAssigner(BaseAssigner):
|
38 |
+
"""Assign a corresponding gt bbox or background to each bbox.
|
39 |
+
|
40 |
+
Each proposals will be assigned with `-1`, `0`, or a positive integer
|
41 |
+
indicating the ground truth index.
|
42 |
+
|
43 |
+
- -1: don't care
|
44 |
+
- 0: negative sample, no assigned gt
|
45 |
+
- positive integer: positive sample, index (1-based) of assigned gt
|
46 |
+
|
47 |
+
Args:
|
48 |
+
center_ratio: ratio of the region in the center of the bbox to
|
49 |
+
define positive sample.
|
50 |
+
ignore_ratio: ratio of the region to define ignore samples.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
|
54 |
+
self.center_ratio = center_ratio
|
55 |
+
self.ignore_ratio = ignore_ratio
|
56 |
+
|
57 |
+
def assign(self,
|
58 |
+
mlvl_anchors,
|
59 |
+
mlvl_valid_flags,
|
60 |
+
gt_bboxes,
|
61 |
+
img_meta,
|
62 |
+
featmap_sizes,
|
63 |
+
anchor_scale,
|
64 |
+
anchor_strides,
|
65 |
+
gt_bboxes_ignore=None,
|
66 |
+
gt_labels=None,
|
67 |
+
allowed_border=0):
|
68 |
+
"""Assign gt to anchors.
|
69 |
+
|
70 |
+
This method assign a gt bbox to every bbox (proposal/anchor), each bbox
|
71 |
+
will be assigned with -1, 0, or a positive number. -1 means don't care,
|
72 |
+
0 means negative sample, positive number is the index (1-based) of
|
73 |
+
assigned gt.
|
74 |
+
The assignment is done in following steps, the order matters.
|
75 |
+
|
76 |
+
1. Assign every anchor to 0 (negative)
|
77 |
+
For each gt_bboxes:
|
78 |
+
2. Compute ignore flags based on ignore_region then
|
79 |
+
assign -1 to anchors w.r.t. ignore flags
|
80 |
+
3. Compute pos flags based on center_region then
|
81 |
+
assign gt_bboxes to anchors w.r.t. pos flags
|
82 |
+
4. Compute ignore flags based on adjacent anchor lvl then
|
83 |
+
assign -1 to anchors w.r.t. ignore flags
|
84 |
+
5. Assign anchor outside of image to -1
|
85 |
+
|
86 |
+
Args:
|
87 |
+
mlvl_anchors (list[Tensor]): Multi level anchors.
|
88 |
+
mlvl_valid_flags (list[Tensor]): Multi level valid flags.
|
89 |
+
gt_bboxes (Tensor): Ground truth bboxes of image
|
90 |
+
img_meta (dict): Meta info of image.
|
91 |
+
featmap_sizes (list[Tensor]): Feature mapsize each level
|
92 |
+
anchor_scale (int): Scale of the anchor.
|
93 |
+
anchor_strides (list[int]): Stride of the anchor.
|
94 |
+
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
|
95 |
+
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
|
96 |
+
labelled as `ignored`, e.g., crowd boxes in COCO.
|
97 |
+
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
|
98 |
+
allowed_border (int, optional): The border to allow the valid
|
99 |
+
anchor. Defaults to 0.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
:obj:`AssignResult`: The assign result.
|
103 |
+
"""
|
104 |
+
if gt_bboxes_ignore is not None:
|
105 |
+
raise NotImplementedError
|
106 |
+
|
107 |
+
num_gts = gt_bboxes.shape[0]
|
108 |
+
num_bboxes = sum(x.shape[0] for x in mlvl_anchors)
|
109 |
+
|
110 |
+
if num_gts == 0 or num_bboxes == 0:
|
111 |
+
# No ground truth or boxes, return empty assignment
|
112 |
+
max_overlaps = gt_bboxes.new_zeros((num_bboxes, ))
|
113 |
+
assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ),
|
114 |
+
dtype=torch.long)
|
115 |
+
if gt_labels is None:
|
116 |
+
assigned_labels = None
|
117 |
+
else:
|
118 |
+
assigned_labels = gt_bboxes.new_full((num_bboxes, ),
|
119 |
+
-1,
|
120 |
+
dtype=torch.long)
|
121 |
+
return AssignResult(
|
122 |
+
num_gts,
|
123 |
+
assigned_gt_inds,
|
124 |
+
max_overlaps,
|
125 |
+
labels=assigned_labels)
|
126 |
+
|
127 |
+
num_lvls = len(mlvl_anchors)
|
128 |
+
r1 = (1 - self.center_ratio) / 2
|
129 |
+
r2 = (1 - self.ignore_ratio) / 2
|
130 |
+
|
131 |
+
scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
|
132 |
+
(gt_bboxes[:, 3] - gt_bboxes[:, 1]))
|
133 |
+
min_anchor_size = scale.new_full(
|
134 |
+
(1, ), float(anchor_scale * anchor_strides[0]))
|
135 |
+
target_lvls = torch.floor(
|
136 |
+
torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
|
137 |
+
target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
|
138 |
+
|
139 |
+
# 1. assign 0 (negative) by default
|
140 |
+
mlvl_assigned_gt_inds = []
|
141 |
+
mlvl_ignore_flags = []
|
142 |
+
for lvl in range(num_lvls):
|
143 |
+
h, w = featmap_sizes[lvl]
|
144 |
+
assert h * w == mlvl_anchors[lvl].shape[0]
|
145 |
+
assigned_gt_inds = gt_bboxes.new_full((h * w, ),
|
146 |
+
0,
|
147 |
+
dtype=torch.long)
|
148 |
+
ignore_flags = torch.zeros_like(assigned_gt_inds)
|
149 |
+
mlvl_assigned_gt_inds.append(assigned_gt_inds)
|
150 |
+
mlvl_ignore_flags.append(ignore_flags)
|
151 |
+
|
152 |
+
for gt_id in range(num_gts):
|
153 |
+
lvl = target_lvls[gt_id].item()
|
154 |
+
featmap_size = featmap_sizes[lvl]
|
155 |
+
stride = anchor_strides[lvl]
|
156 |
+
anchors = mlvl_anchors[lvl]
|
157 |
+
gt_bbox = gt_bboxes[gt_id, :4]
|
158 |
+
|
159 |
+
# Compute regions
|
160 |
+
ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
|
161 |
+
ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)
|
162 |
+
|
163 |
+
# 2. Assign -1 to ignore flags
|
164 |
+
ignore_flags = anchor_ctr_inside_region_flags(
|
165 |
+
anchors, stride, ignore_region)
|
166 |
+
mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
|
167 |
+
|
168 |
+
# 3. Assign gt_bboxes to pos flags
|
169 |
+
pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
|
170 |
+
ctr_region)
|
171 |
+
mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1
|
172 |
+
|
173 |
+
# 4. Assign -1 to ignore adjacent lvl
|
174 |
+
if lvl > 0:
|
175 |
+
d_lvl = lvl - 1
|
176 |
+
d_anchors = mlvl_anchors[d_lvl]
|
177 |
+
d_featmap_size = featmap_sizes[d_lvl]
|
178 |
+
d_stride = anchor_strides[d_lvl]
|
179 |
+
d_ignore_region = calc_region(gt_bbox, r2, d_stride,
|
180 |
+
d_featmap_size)
|
181 |
+
ignore_flags = anchor_ctr_inside_region_flags(
|
182 |
+
d_anchors, d_stride, d_ignore_region)
|
183 |
+
mlvl_ignore_flags[d_lvl][ignore_flags] = 1
|
184 |
+
if lvl < num_lvls - 1:
|
185 |
+
u_lvl = lvl + 1
|
186 |
+
u_anchors = mlvl_anchors[u_lvl]
|
187 |
+
u_featmap_size = featmap_sizes[u_lvl]
|
188 |
+
u_stride = anchor_strides[u_lvl]
|
189 |
+
u_ignore_region = calc_region(gt_bbox, r2, u_stride,
|
190 |
+
u_featmap_size)
|
191 |
+
ignore_flags = anchor_ctr_inside_region_flags(
|
192 |
+
u_anchors, u_stride, u_ignore_region)
|
193 |
+
mlvl_ignore_flags[u_lvl][ignore_flags] = 1
|
194 |
+
|
195 |
+
# 4. (cont.) Assign -1 to ignore adjacent lvl
|
196 |
+
for lvl in range(num_lvls):
|
197 |
+
ignore_flags = mlvl_ignore_flags[lvl]
|
198 |
+
mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
|
199 |
+
|
200 |
+
# 5. Assign -1 to anchor outside of image
|
201 |
+
flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
|
202 |
+
flat_anchors = torch.cat(mlvl_anchors)
|
203 |
+
flat_valid_flags = torch.cat(mlvl_valid_flags)
|
204 |
+
assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
|
205 |
+
flat_valid_flags.shape[0])
|
206 |
+
inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
|
207 |
+
img_meta['img_shape'],
|
208 |
+
allowed_border)
|
209 |
+
outside_flags = ~inside_flags
|
210 |
+
flat_assigned_gt_inds[outside_flags] = -1
|
211 |
+
|
212 |
+
if gt_labels is not None:
|
213 |
+
assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
|
214 |
+
pos_flags = assigned_gt_inds > 0
|
215 |
+
assigned_labels[pos_flags] = gt_labels[
|
216 |
+
flat_assigned_gt_inds[pos_flags] - 1]
|
217 |
+
else:
|
218 |
+
assigned_labels = None
|
219 |
+
|
220 |
+
return AssignResult(
|
221 |
+
num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)
|
mmdet/core/bbox/builder.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcv.utils import Registry, build_from_cfg
|
2 |
+
|
3 |
+
BBOX_ASSIGNERS = Registry('bbox_assigner')
|
4 |
+
BBOX_SAMPLERS = Registry('bbox_sampler')
|
5 |
+
BBOX_CODERS = Registry('bbox_coder')
|
6 |
+
|
7 |
+
|
8 |
+
def build_assigner(cfg, **default_args):
|
9 |
+
"""Builder of box assigner."""
|
10 |
+
return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)
|
11 |
+
|
12 |
+
|
13 |
+
def build_sampler(cfg, **default_args):
|
14 |
+
"""Builder of box sampler."""
|
15 |
+
return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
|
16 |
+
|
17 |
+
|
18 |
+
def build_bbox_coder(cfg, **default_args):
|
19 |
+
"""Builder of box coder."""
|
20 |
+
return build_from_cfg(cfg, BBOX_CODERS, default_args)
|
mmdet/core/bbox/coder/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_bbox_coder import BaseBBoxCoder
|
2 |
+
from .bucketing_bbox_coder import BucketingBBoxCoder
|
3 |
+
from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
|
4 |
+
from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
|
5 |
+
from .pseudo_bbox_coder import PseudoBBoxCoder
|
6 |
+
from .tblr_bbox_coder import TBLRBBoxCoder
|
7 |
+
from .yolo_bbox_coder import YOLOBBoxCoder
|
8 |
+
|
9 |
+
__all__ = [
|
10 |
+
'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
|
11 |
+
'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
|
12 |
+
'BucketingBBoxCoder'
|
13 |
+
]
|
mmdet/core/bbox/coder/base_bbox_coder.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABCMeta, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class BaseBBoxCoder(metaclass=ABCMeta):
|
5 |
+
"""Base bounding box coder."""
|
6 |
+
|
7 |
+
def __init__(self, **kwargs):
|
8 |
+
pass
|
9 |
+
|
10 |
+
@abstractmethod
|
11 |
+
def encode(self, bboxes, gt_bboxes):
|
12 |
+
"""Encode deltas between bboxes and ground truth boxes."""
|
13 |
+
|
14 |
+
@abstractmethod
|
15 |
+
def decode(self, bboxes, bboxes_pred):
|
16 |
+
"""Decode the predicted bboxes according to prediction and base
|
17 |
+
boxes."""
|
mmdet/core/bbox/coder/bucketing_bbox_coder.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mmcv
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from ..builder import BBOX_CODERS
|
7 |
+
from ..transforms import bbox_rescale
|
8 |
+
from .base_bbox_coder import BaseBBoxCoder
|
9 |
+
|
10 |
+
|
11 |
+
@BBOX_CODERS.register_module()
|
12 |
+
class BucketingBBoxCoder(BaseBBoxCoder):
|
13 |
+
"""Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).
|
14 |
+
|
15 |
+
Boundary Localization with Bucketing and Bucketing Guided Rescoring
|
16 |
+
are implemented here.
|
17 |
+
|
18 |
+
Please refer to https://arxiv.org/abs/1912.04260 for more details.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
num_buckets (int): Number of buckets.
|
22 |
+
scale_factor (int): Scale factor of proposals to generate buckets.
|
23 |
+
offset_topk (int): Topk buckets are used to generate
|
24 |
+
bucket fine regression targets. Defaults to 2.
|
25 |
+
offset_upperbound (float): Offset upperbound to generate
|
26 |
+
bucket fine regression targets.
|
27 |
+
To avoid too large offset displacements. Defaults to 1.0.
|
28 |
+
cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
|
29 |
+
Defaults to True.
|
30 |
+
clip_border (bool, optional): Whether clip the objects outside the
|
31 |
+
border of the image. Defaults to True.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
num_buckets,
|
36 |
+
scale_factor,
|
37 |
+
offset_topk=2,
|
38 |
+
offset_upperbound=1.0,
|
39 |
+
cls_ignore_neighbor=True,
|
40 |
+
clip_border=True):
|
41 |
+
super(BucketingBBoxCoder, self).__init__()
|
42 |
+
self.num_buckets = num_buckets
|
43 |
+
self.scale_factor = scale_factor
|
44 |
+
self.offset_topk = offset_topk
|
45 |
+
self.offset_upperbound = offset_upperbound
|
46 |
+
self.cls_ignore_neighbor = cls_ignore_neighbor
|
47 |
+
self.clip_border = clip_border
|
48 |
+
|
49 |
+
def encode(self, bboxes, gt_bboxes):
|
50 |
+
"""Get bucketing estimation and fine regression targets during
|
51 |
+
training.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
bboxes (torch.Tensor): source boxes, e.g., object proposals.
|
55 |
+
gt_bboxes (torch.Tensor): target of the transformation, e.g.,
|
56 |
+
ground truth boxes.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
encoded_bboxes(tuple[Tensor]): bucketing estimation
|
60 |
+
and fine regression targets and weights
|
61 |
+
"""
|
62 |
+
|
63 |
+
assert bboxes.size(0) == gt_bboxes.size(0)
|
64 |
+
assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
|
65 |
+
encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
|
66 |
+
self.scale_factor, self.offset_topk,
|
67 |
+
self.offset_upperbound,
|
68 |
+
self.cls_ignore_neighbor)
|
69 |
+
return encoded_bboxes
|
70 |
+
|
71 |
+
def decode(self, bboxes, pred_bboxes, max_shape=None):
|
72 |
+
"""Apply transformation `pred_bboxes` to `boxes`.
|
73 |
+
Args:
|
74 |
+
boxes (torch.Tensor): Basic boxes.
|
75 |
+
pred_bboxes (torch.Tensor): Predictions for bucketing estimation
|
76 |
+
and fine regression
|
77 |
+
max_shape (tuple[int], optional): Maximum shape of boxes.
|
78 |
+
Defaults to None.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
torch.Tensor: Decoded boxes.
|
82 |
+
"""
|
83 |
+
assert len(pred_bboxes) == 2
|
84 |
+
cls_preds, offset_preds = pred_bboxes
|
85 |
+
assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
|
86 |
+
0) == bboxes.size(0)
|
87 |
+
decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
|
88 |
+
self.num_buckets, self.scale_factor,
|
89 |
+
max_shape, self.clip_border)
|
90 |
+
|
91 |
+
return decoded_bboxes
|
92 |
+
|
93 |
+
|
94 |
+
@mmcv.jit(coderize=True)
|
95 |
+
def generat_buckets(proposals, num_buckets, scale_factor=1.0):
|
96 |
+
"""Generate buckets w.r.t bucket number and scale factor of proposals.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
proposals (Tensor): Shape (n, 4)
|
100 |
+
num_buckets (int): Number of buckets.
|
101 |
+
scale_factor (float): Scale factor to rescale proposals.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
|
105 |
+
t_buckets, d_buckets)
|
106 |
+
|
107 |
+
- bucket_w: Width of buckets on x-axis. Shape (n, ).
|
108 |
+
- bucket_h: Height of buckets on y-axis. Shape (n, ).
|
109 |
+
- l_buckets: Left buckets. Shape (n, ceil(side_num/2)).
|
110 |
+
- r_buckets: Right buckets. Shape (n, ceil(side_num/2)).
|
111 |
+
- t_buckets: Top buckets. Shape (n, ceil(side_num/2)).
|
112 |
+
- d_buckets: Down buckets. Shape (n, ceil(side_num/2)).
|
113 |
+
"""
|
114 |
+
proposals = bbox_rescale(proposals, scale_factor)
|
115 |
+
|
116 |
+
# number of buckets in each side
|
117 |
+
side_num = int(np.ceil(num_buckets / 2.0))
|
118 |
+
pw = proposals[..., 2] - proposals[..., 0]
|
119 |
+
ph = proposals[..., 3] - proposals[..., 1]
|
120 |
+
px1 = proposals[..., 0]
|
121 |
+
py1 = proposals[..., 1]
|
122 |
+
px2 = proposals[..., 2]
|
123 |
+
py2 = proposals[..., 3]
|
124 |
+
|
125 |
+
bucket_w = pw / num_buckets
|
126 |
+
bucket_h = ph / num_buckets
|
127 |
+
|
128 |
+
# left buckets
|
129 |
+
l_buckets = px1[:, None] + (0.5 + torch.arange(
|
130 |
+
0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
|
131 |
+
# right buckets
|
132 |
+
r_buckets = px2[:, None] - (0.5 + torch.arange(
|
133 |
+
0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
|
134 |
+
# top buckets
|
135 |
+
t_buckets = py1[:, None] + (0.5 + torch.arange(
|
136 |
+
0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
|
137 |
+
# down buckets
|
138 |
+
d_buckets = py2[:, None] - (0.5 + torch.arange(
|
139 |
+
0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
|
140 |
+
return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets
|
141 |
+
|
142 |
+
|
143 |
+
@mmcv.jit(coderize=True)
|
144 |
+
def bbox2bucket(proposals,
|
145 |
+
gt,
|
146 |
+
num_buckets,
|
147 |
+
scale_factor,
|
148 |
+
offset_topk=2,
|
149 |
+
offset_upperbound=1.0,
|
150 |
+
cls_ignore_neighbor=True):
|
151 |
+
"""Generate buckets estimation and fine regression targets.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
proposals (Tensor): Shape (n, 4)
|
155 |
+
gt (Tensor): Shape (n, 4)
|
156 |
+
num_buckets (int): Number of buckets.
|
157 |
+
scale_factor (float): Scale factor to rescale proposals.
|
158 |
+
offset_topk (int): Topk buckets are used to generate
|
159 |
+
bucket fine regression targets. Defaults to 2.
|
160 |
+
offset_upperbound (float): Offset allowance to generate
|
161 |
+
bucket fine regression targets.
|
162 |
+
To avoid too large offset displacements. Defaults to 1.0.
|
163 |
+
cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
|
164 |
+
Defaults to True.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).
|
168 |
+
|
169 |
+
- offsets: Fine regression targets. \
|
170 |
+
Shape (n, num_buckets*2).
|
171 |
+
- offsets_weights: Fine regression weights. \
|
172 |
+
Shape (n, num_buckets*2).
|
173 |
+
- bucket_labels: Bucketing estimation labels. \
|
174 |
+
Shape (n, num_buckets*2).
|
175 |
+
- cls_weights: Bucketing estimation weights. \
|
176 |
+
Shape (n, num_buckets*2).
|
177 |
+
"""
|
178 |
+
assert proposals.size() == gt.size()
|
179 |
+
|
180 |
+
# generate buckets
|
181 |
+
proposals = proposals.float()
|
182 |
+
gt = gt.float()
|
183 |
+
(bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
|
184 |
+
d_buckets) = generat_buckets(proposals, num_buckets, scale_factor)
|
185 |
+
|
186 |
+
gx1 = gt[..., 0]
|
187 |
+
gy1 = gt[..., 1]
|
188 |
+
gx2 = gt[..., 2]
|
189 |
+
gy2 = gt[..., 3]
|
190 |
+
|
191 |
+
# generate offset targets and weights
|
192 |
+
# offsets from buckets to gts
|
193 |
+
l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
|
194 |
+
r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
|
195 |
+
t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
|
196 |
+
d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]
|
197 |
+
|
198 |
+
# select top-k nearset buckets
|
199 |
+
l_topk, l_label = l_offsets.abs().topk(
|
200 |
+
offset_topk, dim=1, largest=False, sorted=True)
|
201 |
+
r_topk, r_label = r_offsets.abs().topk(
|
202 |
+
offset_topk, dim=1, largest=False, sorted=True)
|
203 |
+
t_topk, t_label = t_offsets.abs().topk(
|
204 |
+
offset_topk, dim=1, largest=False, sorted=True)
|
205 |
+
d_topk, d_label = d_offsets.abs().topk(
|
206 |
+
offset_topk, dim=1, largest=False, sorted=True)
|
207 |
+
|
208 |
+
offset_l_weights = l_offsets.new_zeros(l_offsets.size())
|
209 |
+
offset_r_weights = r_offsets.new_zeros(r_offsets.size())
|
210 |
+
offset_t_weights = t_offsets.new_zeros(t_offsets.size())
|
211 |
+
offset_d_weights = d_offsets.new_zeros(d_offsets.size())
|
212 |
+
inds = torch.arange(0, proposals.size(0)).to(proposals).long()
|
213 |
+
|
214 |
+
# generate offset weights of top-k nearset buckets
|
215 |
+
for k in range(offset_topk):
|
216 |
+
if k >= 1:
|
217 |
+
offset_l_weights[inds, l_label[:,
|
218 |
+
k]] = (l_topk[:, k] <
|
219 |
+
offset_upperbound).float()
|
220 |
+
offset_r_weights[inds, r_label[:,
|
221 |
+
k]] = (r_topk[:, k] <
|
222 |
+
offset_upperbound).float()
|
223 |
+
offset_t_weights[inds, t_label[:,
|
224 |
+
k]] = (t_topk[:, k] <
|
225 |
+
offset_upperbound).float()
|
226 |
+
offset_d_weights[inds, d_label[:,
|
227 |
+
k]] = (d_topk[:, k] <
|
228 |
+
offset_upperbound).float()
|
229 |
+
else:
|
230 |
+
offset_l_weights[inds, l_label[:, k]] = 1.0
|
231 |
+
offset_r_weights[inds, r_label[:, k]] = 1.0
|
232 |
+
offset_t_weights[inds, t_label[:, k]] = 1.0
|
233 |
+
offset_d_weights[inds, d_label[:, k]] = 1.0
|
234 |
+
|
235 |
+
offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
|
236 |
+
offsets_weights = torch.cat([
|
237 |
+
offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights
|
238 |
+
],
|
239 |
+
dim=-1)
|
240 |
+
|
241 |
+
# generate bucket labels and weight
|
242 |
+
side_num = int(np.ceil(num_buckets / 2.0))
|
243 |
+
labels = torch.stack(
|
244 |
+
[l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)
|
245 |
+
|
246 |
+
batch_size = labels.size(0)
|
247 |
+
bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size,
|
248 |
+
-1).float()
|
249 |
+
bucket_cls_l_weights = (l_offsets.abs() < 1).float()
|
250 |
+
bucket_cls_r_weights = (r_offsets.abs() < 1).float()
|
251 |
+
bucket_cls_t_weights = (t_offsets.abs() < 1).float()
|
252 |
+
bucket_cls_d_weights = (d_offsets.abs() < 1).float()
|
253 |
+
bucket_cls_weights = torch.cat([
|
254 |
+
bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
|
255 |
+
bucket_cls_d_weights
|
256 |
+
],
|
257 |
+
dim=-1)
|
258 |
+
# ignore second nearest buckets for cls if necessary
|
259 |
+
if cls_ignore_neighbor:
|
260 |
+
bucket_cls_weights = (~((bucket_cls_weights == 1) &
|
261 |
+
(bucket_labels == 0))).float()
|
262 |
+
else:
|
263 |
+
bucket_cls_weights[:] = 1.0
|
264 |
+
return offsets, offsets_weights, bucket_labels, bucket_cls_weights
|
265 |
+
|
266 |
+
|
267 |
+
@mmcv.jit(coderize=True)
|
268 |
+
def bucket2bbox(proposals,
|
269 |
+
cls_preds,
|
270 |
+
offset_preds,
|
271 |
+
num_buckets,
|
272 |
+
scale_factor=1.0,
|
273 |
+
max_shape=None,
|
274 |
+
clip_border=True):
|
275 |
+
"""Apply bucketing estimation (cls preds) and fine regression (offset
|
276 |
+
preds) to generate det bboxes.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
proposals (Tensor): Boxes to be transformed. Shape (n, 4)
|
280 |
+
cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2).
|
281 |
+
offset_preds (Tensor): fine regression. Shape (n, num_buckets*2).
|
282 |
+
num_buckets (int): Number of buckets.
|
283 |
+
scale_factor (float): Scale factor to rescale proposals.
|
284 |
+
max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
|
285 |
+
clip_border (bool, optional): Whether clip the objects outside the
|
286 |
+
border of the image. Defaults to True.
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
tuple[Tensor]: (bboxes, loc_confidence).
|
290 |
+
|
291 |
+
- bboxes: predicted bboxes. Shape (n, 4)
|
292 |
+
- loc_confidence: localization confidence of predicted bboxes.
|
293 |
+
Shape (n,).
|
294 |
+
"""
|
295 |
+
|
296 |
+
side_num = int(np.ceil(num_buckets / 2.0))
|
297 |
+
cls_preds = cls_preds.view(-1, side_num)
|
298 |
+
offset_preds = offset_preds.view(-1, side_num)
|
299 |
+
|
300 |
+
scores = F.softmax(cls_preds, dim=1)
|
301 |
+
score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)
|
302 |
+
|
303 |
+
rescaled_proposals = bbox_rescale(proposals, scale_factor)
|
304 |
+
|
305 |
+
pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
|
306 |
+
ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
|
307 |
+
px1 = rescaled_proposals[..., 0]
|
308 |
+
py1 = rescaled_proposals[..., 1]
|
309 |
+
px2 = rescaled_proposals[..., 2]
|
310 |
+
py2 = rescaled_proposals[..., 3]
|
311 |
+
|
312 |
+
bucket_w = pw / num_buckets
|
313 |
+
bucket_h = ph / num_buckets
|
314 |
+
|
315 |
+
score_inds_l = score_label[0::4, 0]
|
316 |
+
score_inds_r = score_label[1::4, 0]
|
317 |
+
score_inds_t = score_label[2::4, 0]
|
318 |
+
score_inds_d = score_label[3::4, 0]
|
319 |
+
l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
|
320 |
+
r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
|
321 |
+
t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
|
322 |
+
d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h
|
323 |
+
|
324 |
+
offsets = offset_preds.view(-1, 4, side_num)
|
325 |
+
inds = torch.arange(proposals.size(0)).to(proposals).long()
|
326 |
+
l_offsets = offsets[:, 0, :][inds, score_inds_l]
|
327 |
+
r_offsets = offsets[:, 1, :][inds, score_inds_r]
|
328 |
+
t_offsets = offsets[:, 2, :][inds, score_inds_t]
|
329 |
+
d_offsets = offsets[:, 3, :][inds, score_inds_d]
|
330 |
+
|
331 |
+
x1 = l_buckets - l_offsets * bucket_w
|
332 |
+
x2 = r_buckets - r_offsets * bucket_w
|
333 |
+
y1 = t_buckets - t_offsets * bucket_h
|
334 |
+
y2 = d_buckets - d_offsets * bucket_h
|
335 |
+
|
336 |
+
if clip_border and max_shape is not None:
|
337 |
+
x1 = x1.clamp(min=0, max=max_shape[1] - 1)
|
338 |
+
y1 = y1.clamp(min=0, max=max_shape[0] - 1)
|
339 |
+
x2 = x2.clamp(min=0, max=max_shape[1] - 1)
|
340 |
+
y2 = y2.clamp(min=0, max=max_shape[0] - 1)
|
341 |
+
bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
|
342 |
+
dim=-1)
|
343 |
+
|
344 |
+
# bucketing guided rescoring
|
345 |
+
loc_confidence = score_topk[:, 0]
|
346 |
+
top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
|
347 |
+
loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
|
348 |
+
loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)
|
349 |
+
|
350 |
+
return bboxes, loc_confidence
|
mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mmcv
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from ..builder import BBOX_CODERS
|
6 |
+
from .base_bbox_coder import BaseBBoxCoder
|
7 |
+
|
8 |
+
|
9 |
+
@BBOX_CODERS.register_module()
|
10 |
+
class DeltaXYWHBBoxCoder(BaseBBoxCoder):
|
11 |
+
"""Delta XYWH BBox coder.
|
12 |
+
|
13 |
+
Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
|
14 |
+
this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
|
15 |
+
decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).
|
16 |
+
|
17 |
+
Args:
|
18 |
+
target_means (Sequence[float]): Denormalizing means of target for
|
19 |
+
delta coordinates
|
20 |
+
target_stds (Sequence[float]): Denormalizing standard deviation of
|
21 |
+
target for delta coordinates
|
22 |
+
clip_border (bool, optional): Whether clip the objects outside the
|
23 |
+
border of the image. Defaults to True.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self,
|
27 |
+
target_means=(0., 0., 0., 0.),
|
28 |
+
target_stds=(1., 1., 1., 1.),
|
29 |
+
clip_border=True):
|
30 |
+
super(BaseBBoxCoder, self).__init__()
|
31 |
+
self.means = target_means
|
32 |
+
self.stds = target_stds
|
33 |
+
self.clip_border = clip_border
|
34 |
+
|
35 |
+
def encode(self, bboxes, gt_bboxes):
|
36 |
+
"""Get box regression transformation deltas that can be used to
|
37 |
+
transform the ``bboxes`` into the ``gt_bboxes``.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
bboxes (torch.Tensor): Source boxes, e.g., object proposals.
|
41 |
+
gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
|
42 |
+
ground-truth boxes.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
torch.Tensor: Box transformation deltas
|
46 |
+
"""
|
47 |
+
|
48 |
+
assert bboxes.size(0) == gt_bboxes.size(0)
|
49 |
+
assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
|
50 |
+
encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
|
51 |
+
return encoded_bboxes
|
52 |
+
|
53 |
+
def decode(self,
|
54 |
+
bboxes,
|
55 |
+
pred_bboxes,
|
56 |
+
max_shape=None,
|
57 |
+
wh_ratio_clip=16 / 1000):
|
58 |
+
"""Apply transformation `pred_bboxes` to `boxes`.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
|
62 |
+
pred_bboxes (Tensor): Encoded offsets with respect to each roi.
|
63 |
+
Has shape (B, N, num_classes * 4) or (B, N, 4) or
|
64 |
+
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
|
65 |
+
when rois is a grid of anchors.Offset encoding follows [1]_.
|
66 |
+
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
67 |
+
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
68 |
+
(H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
|
69 |
+
the max_shape should be a Sequence[Sequence[int]]
|
70 |
+
and the length of max_shape should also be B.
|
71 |
+
wh_ratio_clip (float, optional): The allowed ratio between
|
72 |
+
width and height.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
torch.Tensor: Decoded boxes.
|
76 |
+
"""
|
77 |
+
|
78 |
+
assert pred_bboxes.size(0) == bboxes.size(0)
|
79 |
+
if pred_bboxes.ndim == 3:
|
80 |
+
assert pred_bboxes.size(1) == bboxes.size(1)
|
81 |
+
decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
|
82 |
+
max_shape, wh_ratio_clip, self.clip_border)
|
83 |
+
|
84 |
+
return decoded_bboxes
|
85 |
+
|
86 |
+
|
87 |
+
@mmcv.jit(coderize=True)
|
88 |
+
def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
|
89 |
+
"""Compute deltas of proposals w.r.t. gt.
|
90 |
+
|
91 |
+
We usually compute the deltas of x, y, w, h of proposals w.r.t ground
|
92 |
+
truth bboxes to get regression target.
|
93 |
+
This is the inverse function of :func:`delta2bbox`.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
|
97 |
+
gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
|
98 |
+
means (Sequence[float]): Denormalizing means for delta coordinates
|
99 |
+
stds (Sequence[float]): Denormalizing standard deviation for delta
|
100 |
+
coordinates
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
Tensor: deltas with shape (N, 4), where columns represent dx, dy,
|
104 |
+
dw, dh.
|
105 |
+
"""
|
106 |
+
assert proposals.size() == gt.size()
|
107 |
+
|
108 |
+
proposals = proposals.float()
|
109 |
+
gt = gt.float()
|
110 |
+
px = (proposals[..., 0] + proposals[..., 2]) * 0.5
|
111 |
+
py = (proposals[..., 1] + proposals[..., 3]) * 0.5
|
112 |
+
pw = proposals[..., 2] - proposals[..., 0]
|
113 |
+
ph = proposals[..., 3] - proposals[..., 1]
|
114 |
+
|
115 |
+
gx = (gt[..., 0] + gt[..., 2]) * 0.5
|
116 |
+
gy = (gt[..., 1] + gt[..., 3]) * 0.5
|
117 |
+
gw = gt[..., 2] - gt[..., 0]
|
118 |
+
gh = gt[..., 3] - gt[..., 1]
|
119 |
+
|
120 |
+
dx = (gx - px) / pw
|
121 |
+
dy = (gy - py) / ph
|
122 |
+
dw = torch.log(gw / pw)
|
123 |
+
dh = torch.log(gh / ph)
|
124 |
+
deltas = torch.stack([dx, dy, dw, dh], dim=-1)
|
125 |
+
|
126 |
+
means = deltas.new_tensor(means).unsqueeze(0)
|
127 |
+
stds = deltas.new_tensor(stds).unsqueeze(0)
|
128 |
+
deltas = deltas.sub_(means).div_(stds)
|
129 |
+
|
130 |
+
return deltas
|
131 |
+
|
132 |
+
|
133 |
+
@mmcv.jit(coderize=True)
|
134 |
+
def delta2bbox(rois,
|
135 |
+
deltas,
|
136 |
+
means=(0., 0., 0., 0.),
|
137 |
+
stds=(1., 1., 1., 1.),
|
138 |
+
max_shape=None,
|
139 |
+
wh_ratio_clip=16 / 1000,
|
140 |
+
clip_border=True):
|
141 |
+
"""Apply deltas to shift/scale base boxes.
|
142 |
+
|
143 |
+
Typically the rois are anchor or proposed bounding boxes and the deltas are
|
144 |
+
network outputs used to shift/scale those boxes.
|
145 |
+
This is the inverse function of :func:`bbox2delta`.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
|
149 |
+
deltas (Tensor): Encoded offsets with respect to each roi.
|
150 |
+
Has shape (B, N, num_classes * 4) or (B, N, 4) or
|
151 |
+
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
|
152 |
+
when rois is a grid of anchors.Offset encoding follows [1]_.
|
153 |
+
means (Sequence[float]): Denormalizing means for delta coordinates
|
154 |
+
stds (Sequence[float]): Denormalizing standard deviation for delta
|
155 |
+
coordinates
|
156 |
+
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
157 |
+
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
158 |
+
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
|
159 |
+
the max_shape should be a Sequence[Sequence[int]]
|
160 |
+
and the length of max_shape should also be B.
|
161 |
+
wh_ratio_clip (float): Maximum aspect ratio for boxes.
|
162 |
+
clip_border (bool, optional): Whether clip the objects outside the
|
163 |
+
border of the image. Defaults to True.
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
|
167 |
+
(N, num_classes * 4) or (N, 4), where 4 represent
|
168 |
+
tl_x, tl_y, br_x, br_y.
|
169 |
+
|
170 |
+
References:
|
171 |
+
.. [1] https://arxiv.org/abs/1311.2524
|
172 |
+
|
173 |
+
Example:
|
174 |
+
>>> rois = torch.Tensor([[ 0., 0., 1., 1.],
|
175 |
+
>>> [ 0., 0., 1., 1.],
|
176 |
+
>>> [ 0., 0., 1., 1.],
|
177 |
+
>>> [ 5., 5., 5., 5.]])
|
178 |
+
>>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
|
179 |
+
>>> [ 1., 1., 1., 1.],
|
180 |
+
>>> [ 0., 0., 2., -1.],
|
181 |
+
>>> [ 0.7, -1.9, -0.5, 0.3]])
|
182 |
+
>>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
|
183 |
+
tensor([[0.0000, 0.0000, 1.0000, 1.0000],
|
184 |
+
[0.1409, 0.1409, 2.8591, 2.8591],
|
185 |
+
[0.0000, 0.3161, 4.1945, 0.6839],
|
186 |
+
[5.0000, 5.0000, 5.0000, 5.0000]])
|
187 |
+
"""
|
188 |
+
means = deltas.new_tensor(means).view(1,
|
189 |
+
-1).repeat(1,
|
190 |
+
deltas.size(-1) // 4)
|
191 |
+
stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
|
192 |
+
denorm_deltas = deltas * stds + means
|
193 |
+
dx = denorm_deltas[..., 0::4]
|
194 |
+
dy = denorm_deltas[..., 1::4]
|
195 |
+
dw = denorm_deltas[..., 2::4]
|
196 |
+
dh = denorm_deltas[..., 3::4]
|
197 |
+
max_ratio = np.abs(np.log(wh_ratio_clip))
|
198 |
+
dw = dw.clamp(min=-max_ratio, max=max_ratio)
|
199 |
+
dh = dh.clamp(min=-max_ratio, max=max_ratio)
|
200 |
+
x1, y1 = rois[..., 0], rois[..., 1]
|
201 |
+
x2, y2 = rois[..., 2], rois[..., 3]
|
202 |
+
# Compute center of each roi
|
203 |
+
px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
|
204 |
+
py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
|
205 |
+
# Compute width/height of each roi
|
206 |
+
pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
|
207 |
+
ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
|
208 |
+
# Use exp(network energy) to enlarge/shrink each roi
|
209 |
+
gw = pw * dw.exp()
|
210 |
+
gh = ph * dh.exp()
|
211 |
+
# Use network energy to shift the center of each roi
|
212 |
+
gx = px + pw * dx
|
213 |
+
gy = py + ph * dy
|
214 |
+
# Convert center-xy/width/height to top-left, bottom-right
|
215 |
+
x1 = gx - gw * 0.5
|
216 |
+
y1 = gy - gh * 0.5
|
217 |
+
x2 = gx + gw * 0.5
|
218 |
+
y2 = gy + gh * 0.5
|
219 |
+
|
220 |
+
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
|
221 |
+
|
222 |
+
if clip_border and max_shape is not None:
|
223 |
+
if not isinstance(max_shape, torch.Tensor):
|
224 |
+
max_shape = x1.new_tensor(max_shape)
|
225 |
+
max_shape = max_shape[..., :2].type_as(x1)
|
226 |
+
if max_shape.ndim == 2:
|
227 |
+
assert bboxes.ndim == 3
|
228 |
+
assert max_shape.size(0) == bboxes.size(0)
|
229 |
+
|
230 |
+
min_xy = x1.new_tensor(0)
|
231 |
+
max_xy = torch.cat(
|
232 |
+
[max_shape] * (deltas.size(-1) // 2),
|
233 |
+
dim=-1).flip(-1).unsqueeze(-2)
|
234 |
+
bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
|
235 |
+
bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
|
236 |
+
|
237 |
+
return bboxes
|
mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mmcv
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from ..builder import BBOX_CODERS
|
6 |
+
from .base_bbox_coder import BaseBBoxCoder
|
7 |
+
|
8 |
+
|
9 |
+
@BBOX_CODERS.register_module()
|
10 |
+
class LegacyDeltaXYWHBBoxCoder(BaseBBoxCoder):
|
11 |
+
"""Legacy Delta XYWH BBox coder used in MMDet V1.x.
|
12 |
+
|
13 |
+
Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2,
|
14 |
+
y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh)
|
15 |
+
back to original bbox (x1, y1, x2, y2).
|
16 |
+
|
17 |
+
Note:
|
18 |
+
The main difference between :class`LegacyDeltaXYWHBBoxCoder` and
|
19 |
+
:class:`DeltaXYWHBBoxCoder` is whether ``+ 1`` is used during width and
|
20 |
+
height calculation. We suggest to only use this coder when testing with
|
21 |
+
MMDet V1.x models.
|
22 |
+
|
23 |
+
References:
|
24 |
+
.. [1] https://arxiv.org/abs/1311.2524
|
25 |
+
|
26 |
+
Args:
|
27 |
+
target_means (Sequence[float]): denormalizing means of target for
|
28 |
+
delta coordinates
|
29 |
+
target_stds (Sequence[float]): denormalizing standard deviation of
|
30 |
+
target for delta coordinates
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self,
|
34 |
+
target_means=(0., 0., 0., 0.),
|
35 |
+
target_stds=(1., 1., 1., 1.)):
|
36 |
+
super(BaseBBoxCoder, self).__init__()
|
37 |
+
self.means = target_means
|
38 |
+
self.stds = target_stds
|
39 |
+
|
40 |
+
def encode(self, bboxes, gt_bboxes):
|
41 |
+
"""Get box regression transformation deltas that can be used to
|
42 |
+
transform the ``bboxes`` into the ``gt_bboxes``.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
bboxes (torch.Tensor): source boxes, e.g., object proposals.
|
46 |
+
gt_bboxes (torch.Tensor): target of the transformation, e.g.,
|
47 |
+
ground-truth boxes.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
torch.Tensor: Box transformation deltas
|
51 |
+
"""
|
52 |
+
assert bboxes.size(0) == gt_bboxes.size(0)
|
53 |
+
assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
|
54 |
+
encoded_bboxes = legacy_bbox2delta(bboxes, gt_bboxes, self.means,
|
55 |
+
self.stds)
|
56 |
+
return encoded_bboxes
|
57 |
+
|
58 |
+
def decode(self,
|
59 |
+
bboxes,
|
60 |
+
pred_bboxes,
|
61 |
+
max_shape=None,
|
62 |
+
wh_ratio_clip=16 / 1000):
|
63 |
+
"""Apply transformation `pred_bboxes` to `boxes`.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
boxes (torch.Tensor): Basic boxes.
|
67 |
+
pred_bboxes (torch.Tensor): Encoded boxes with shape
|
68 |
+
max_shape (tuple[int], optional): Maximum shape of boxes.
|
69 |
+
Defaults to None.
|
70 |
+
wh_ratio_clip (float, optional): The allowed ratio between
|
71 |
+
width and height.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: Decoded boxes.
|
75 |
+
"""
|
76 |
+
assert pred_bboxes.size(0) == bboxes.size(0)
|
77 |
+
decoded_bboxes = legacy_delta2bbox(bboxes, pred_bboxes, self.means,
|
78 |
+
self.stds, max_shape, wh_ratio_clip)
|
79 |
+
|
80 |
+
return decoded_bboxes
|
81 |
+
|
82 |
+
|
83 |
+
@mmcv.jit(coderize=True)
|
84 |
+
def legacy_bbox2delta(proposals,
|
85 |
+
gt,
|
86 |
+
means=(0., 0., 0., 0.),
|
87 |
+
stds=(1., 1., 1., 1.)):
|
88 |
+
"""Compute deltas of proposals w.r.t. gt in the MMDet V1.x manner.
|
89 |
+
|
90 |
+
We usually compute the deltas of x, y, w, h of proposals w.r.t ground
|
91 |
+
truth bboxes to get regression target.
|
92 |
+
This is the inverse function of `delta2bbox()`
|
93 |
+
|
94 |
+
Args:
|
95 |
+
proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
|
96 |
+
gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
|
97 |
+
means (Sequence[float]): Denormalizing means for delta coordinates
|
98 |
+
stds (Sequence[float]): Denormalizing standard deviation for delta
|
99 |
+
coordinates
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Tensor: deltas with shape (N, 4), where columns represent dx, dy,
|
103 |
+
dw, dh.
|
104 |
+
"""
|
105 |
+
assert proposals.size() == gt.size()
|
106 |
+
|
107 |
+
proposals = proposals.float()
|
108 |
+
gt = gt.float()
|
109 |
+
px = (proposals[..., 0] + proposals[..., 2]) * 0.5
|
110 |
+
py = (proposals[..., 1] + proposals[..., 3]) * 0.5
|
111 |
+
pw = proposals[..., 2] - proposals[..., 0] + 1.0
|
112 |
+
ph = proposals[..., 3] - proposals[..., 1] + 1.0
|
113 |
+
|
114 |
+
gx = (gt[..., 0] + gt[..., 2]) * 0.5
|
115 |
+
gy = (gt[..., 1] + gt[..., 3]) * 0.5
|
116 |
+
gw = gt[..., 2] - gt[..., 0] + 1.0
|
117 |
+
gh = gt[..., 3] - gt[..., 1] + 1.0
|
118 |
+
|
119 |
+
dx = (gx - px) / pw
|
120 |
+
dy = (gy - py) / ph
|
121 |
+
dw = torch.log(gw / pw)
|
122 |
+
dh = torch.log(gh / ph)
|
123 |
+
deltas = torch.stack([dx, dy, dw, dh], dim=-1)
|
124 |
+
|
125 |
+
means = deltas.new_tensor(means).unsqueeze(0)
|
126 |
+
stds = deltas.new_tensor(stds).unsqueeze(0)
|
127 |
+
deltas = deltas.sub_(means).div_(stds)
|
128 |
+
|
129 |
+
return deltas
|
130 |
+
|
131 |
+
|
132 |
+
@mmcv.jit(coderize=True)
|
133 |
+
def legacy_delta2bbox(rois,
|
134 |
+
deltas,
|
135 |
+
means=(0., 0., 0., 0.),
|
136 |
+
stds=(1., 1., 1., 1.),
|
137 |
+
max_shape=None,
|
138 |
+
wh_ratio_clip=16 / 1000):
|
139 |
+
"""Apply deltas to shift/scale base boxes in the MMDet V1.x manner.
|
140 |
+
|
141 |
+
Typically the rois are anchor or proposed bounding boxes and the deltas are
|
142 |
+
network outputs used to shift/scale those boxes.
|
143 |
+
This is the inverse function of `bbox2delta()`
|
144 |
+
|
145 |
+
Args:
|
146 |
+
rois (Tensor): Boxes to be transformed. Has shape (N, 4)
|
147 |
+
deltas (Tensor): Encoded offsets with respect to each roi.
|
148 |
+
Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when
|
149 |
+
rois is a grid of anchors. Offset encoding follows [1]_.
|
150 |
+
means (Sequence[float]): Denormalizing means for delta coordinates
|
151 |
+
stds (Sequence[float]): Denormalizing standard deviation for delta
|
152 |
+
coordinates
|
153 |
+
max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
|
154 |
+
wh_ratio_clip (float): Maximum aspect ratio for boxes.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
Tensor: Boxes with shape (N, 4), where columns represent
|
158 |
+
tl_x, tl_y, br_x, br_y.
|
159 |
+
|
160 |
+
References:
|
161 |
+
.. [1] https://arxiv.org/abs/1311.2524
|
162 |
+
|
163 |
+
Example:
|
164 |
+
>>> rois = torch.Tensor([[ 0., 0., 1., 1.],
|
165 |
+
>>> [ 0., 0., 1., 1.],
|
166 |
+
>>> [ 0., 0., 1., 1.],
|
167 |
+
>>> [ 5., 5., 5., 5.]])
|
168 |
+
>>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
|
169 |
+
>>> [ 1., 1., 1., 1.],
|
170 |
+
>>> [ 0., 0., 2., -1.],
|
171 |
+
>>> [ 0.7, -1.9, -0.5, 0.3]])
|
172 |
+
>>> legacy_delta2bbox(rois, deltas, max_shape=(32, 32))
|
173 |
+
tensor([[0.0000, 0.0000, 1.5000, 1.5000],
|
174 |
+
[0.0000, 0.0000, 5.2183, 5.2183],
|
175 |
+
[0.0000, 0.1321, 7.8891, 0.8679],
|
176 |
+
[5.3967, 2.4251, 6.0033, 3.7749]])
|
177 |
+
"""
|
178 |
+
means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
|
179 |
+
stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
|
180 |
+
denorm_deltas = deltas * stds + means
|
181 |
+
dx = denorm_deltas[:, 0::4]
|
182 |
+
dy = denorm_deltas[:, 1::4]
|
183 |
+
dw = denorm_deltas[:, 2::4]
|
184 |
+
dh = denorm_deltas[:, 3::4]
|
185 |
+
max_ratio = np.abs(np.log(wh_ratio_clip))
|
186 |
+
dw = dw.clamp(min=-max_ratio, max=max_ratio)
|
187 |
+
dh = dh.clamp(min=-max_ratio, max=max_ratio)
|
188 |
+
# Compute center of each roi
|
189 |
+
px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
|
190 |
+
py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
|
191 |
+
# Compute width/height of each roi
|
192 |
+
pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
|
193 |
+
ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
|
194 |
+
# Use exp(network energy) to enlarge/shrink each roi
|
195 |
+
gw = pw * dw.exp()
|
196 |
+
gh = ph * dh.exp()
|
197 |
+
# Use network energy to shift the center of each roi
|
198 |
+
gx = px + pw * dx
|
199 |
+
gy = py + ph * dy
|
200 |
+
# Convert center-xy/width/height to top-left, bottom-right
|
201 |
+
|
202 |
+
# The true legacy box coder should +- 0.5 here.
|
203 |
+
# However, current implementation improves the performance when testing
|
204 |
+
# the models trained in MMDetection 1.X (~0.5 bbox AP, 0.2 mask AP)
|
205 |
+
x1 = gx - gw * 0.5
|
206 |
+
y1 = gy - gh * 0.5
|
207 |
+
x2 = gx + gw * 0.5
|
208 |
+
y2 = gy + gh * 0.5
|
209 |
+
if max_shape is not None:
|
210 |
+
x1 = x1.clamp(min=0, max=max_shape[1] - 1)
|
211 |
+
y1 = y1.clamp(min=0, max=max_shape[0] - 1)
|
212 |
+
x2 = x2.clamp(min=0, max=max_shape[1] - 1)
|
213 |
+
y2 = y2.clamp(min=0, max=max_shape[0] - 1)
|
214 |
+
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
|
215 |
+
return bboxes
|