|
|
|
import os.path as osp |
|
from unittest import TestCase |
|
|
|
import mmcv |
|
import pytest |
|
|
|
from mmdet.datasets.transforms import * |
|
from mmdet.registry import TRANSFORMS |
|
|
|
|
|
class TestMuitiScaleFlipAug(TestCase):
    """Tests for ``TestTimeAug`` built through the mmdet ``TRANSFORMS`` registry.

    NOTE: "Muiti" is a typo of "Multi"; the class name is kept unchanged so
    existing test selections (``-k`` expressions, node IDs) keep working.
    """

    @staticmethod
    def _load_results():
        """Return a fresh results dict holding ``../data/color.jpg``."""
        img = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        return dict(
            img_id='1',
            img_path='data/color.jpg',
            img=img,
            ori_shape=img.shape,
            ori_height=img.shape[0],
            ori_width=img.shape[1],
            pad_shape=img.shape,
            scale_factor=1.0)

    @staticmethod
    def _pack_cfg(with_flip=False):
        """Return the ``PackDetInputs`` config, optionally keeping flip meta."""
        meta_keys = ('img_id', 'img_path', 'ori_shape', 'img_shape',
                     'scale_factor')
        if with_flip:
            meta_keys += ('flip', 'flip_direction')
        return dict(type='mmdet.PackDetInputs', meta_keys=meta_keys)

    @staticmethod
    def _flip_cfgs():
        """Return a never-flip / always-flip pair of horizontal RandomFlips."""
        return [
            dict(type='RandomFlip', prob=0., direction='horizontal'),
            dict(type='RandomFlip', prob=1., direction='horizontal'),
        ]

    @staticmethod
    def _run_tta(transforms, results):
        """Build ``TestTimeAug`` from ``transforms`` and apply it.

        A shallow copy of ``results`` is passed so keys written by the
        pipeline do not leak into the next case.
        """
        tta_module = TRANSFORMS.build(
            dict(type='TestTimeAug', transforms=transforms))
        return tta_module(results.copy())

    @staticmethod
    def _shapes(tta_results):
        """Return the shape of every packed input as a plain tuple."""
        return [tuple(inp.shape) for inp in tta_results['inputs']]

    @staticmethod
    def _flips(tta_results):
        """Return the ``flip`` meta value of every packed data sample."""
        return [
            sample.metainfo['flip'] for sample in tta_results['data_samples']
        ]

    def test_exception(self):
        """A flat transform list (not a list of lists) must raise TypeError."""
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        # Only the build step is expected to raise.
        with pytest.raises(TypeError):
            TRANSFORMS.build(tta_transform)

    def test_multi_scale_flip_aug(self):
        """Check output count, shapes and flip flags for several TTA setups.

        The branch lists in ``transforms`` are combined with each other: N
        resizes x M flips yield N * M packed outputs, with the flip branch
        varying fastest.
        """
        results = self._load_results()
        fixed_scales = [(256, 256), (512, 512), (1024, 1024)]
        ratios = [0.5, 1.0, 2.0]

        # Several fixed output sizes, no flip: one output per scale.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale=scale, keep_ratio=False)
                for scale in fixed_scales
            ],
            [self._pack_cfg()],
        ], results)
        assert self._shapes(tta_results) == [(3, 256, 256), (3, 512, 512),
                                             (3, 1024, 1024)]

        # Several fixed sizes x {no flip, flip}: 3 * 2 outputs.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale=scale, keep_ratio=False)
                for scale in fixed_scales
            ],
            self._flip_cfgs(),
            [self._pack_cfg(with_flip=True)],
        ], results)
        assert self._shapes(tta_results) == [(3, 256, 256), (3, 256, 256),
                                             (3, 512, 512), (3, 512, 512),
                                             (3, 1024, 1024),
                                             (3, 1024, 1024)]
        assert self._flips(tta_results) == [False, True] * 3

        # A single scale must produce exactly one output; compare the whole
        # list (not just ``inputs[0]``) so unexpected extra outputs fail.
        tta_results = self._run_tta([
            [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
            [self._pack_cfg()],
        ], results)
        assert self._shapes(tta_results) == [(3, 512, 512)]

        # Single scale x {no flip, flip}.
        tta_results = self._run_tta([
            [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
            self._flip_cfgs(),
            [self._pack_cfg(with_flip=True)],
        ], results)
        assert self._shapes(tta_results) == [(3, 512, 512), (3, 512, 512)]
        assert self._flips(tta_results) == [False, True]

        # Relative scale factors. Factor 1.0 yields (3, 288, 512), i.e. the
        # test image is 288x512 (HxW), so 0.5 / 2.0 halve / double it.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale_factor=r, keep_ratio=False)
                for r in ratios
            ],
            [self._pack_cfg()],
        ], results)
        assert self._shapes(tta_results) == [(3, 144, 256), (3, 288, 512),
                                             (3, 576, 1024)]

        # Relative scale factors with keep_ratio=True x {no flip, flip}.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in ratios
            ],
            self._flip_cfgs(),
            [self._pack_cfg(with_flip=True)],
        ], results)
        assert self._shapes(tta_results) == [(3, 144, 256), (3, 144, 256),
                                             (3, 288, 512), (3, 288, 512),
                                             (3, 576, 1024),
                                             (3, 576, 1024)]
        assert self._flips(tta_results) == [False, True] * 3
|
|