# Saurabh1105's picture
# MMdet Model for Image Segmentation
# 6c9ac8f
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from unittest import TestCase
import mmcv
import pytest
from mmdet.datasets.transforms import * # noqa
from mmdet.registry import TRANSFORMS
class TestMuitiScaleFlipAug(TestCase):
    """Exercise the ``TestTimeAug`` wrapper built through ``TRANSFORMS``.

    A pipeline is a list of steps and every step is a list of alternative
    transform configs. The assertions in this class expect one output per
    combination of alternatives, with the later step varying fastest
    (e.g. flip flags ``[False, True, False, True, ...]`` across scales).
    """

    # Meta keys packed by every pipeline; flip pipelines append two more.
    _COMMON_META_KEYS = ('img_id', 'img_path', 'ori_shape', 'img_shape',
                         'scale_factor')
    _FLIP_META_KEYS = ('flip', 'flip_direction')

    @staticmethod
    def _resize_options(arg_name, values, keep_ratio=False):
        """Return one ``Resize`` config per entry of ``values``.

        ``arg_name`` is the ``Resize`` keyword that receives the value
        (``'scale'`` or ``'scale_factor'`` in this file).
        """
        return [{
            'type': 'Resize',
            arg_name: value,
            'keep_ratio': keep_ratio
        } for value in values]

    @staticmethod
    def _flip_options():
        """Return two horizontal ``RandomFlip`` configs: prob 0. and 1."""
        return [
            dict(type='RandomFlip', prob=probability, direction='horizontal')
            for probability in (0., 1.)
        ]

    @classmethod
    def _pack_options(cls, with_flip=False):
        """Return the single-alternative ``PackDetInputs`` step."""
        keys = cls._COMMON_META_KEYS
        if with_flip:
            keys = keys + cls._FLIP_META_KEYS
        return [dict(type='mmdet.PackDetInputs', meta_keys=keys)]

    @staticmethod
    def _apply(steps, sample):
        """Build a ``TestTimeAug`` from ``steps``; run it on a shallow copy.

        The copy keeps keys added by one pipeline from leaking into the
        shared ``sample`` dict reused by the next pipeline.
        """
        wrapper = TRANSFORMS.build(dict(type='TestTimeAug', transforms=steps))
        return wrapper(sample.copy())

    @staticmethod
    def _shapes(output):
        """Collect ``.shape`` of every entry in ``output['inputs']``."""
        return [item.shape for item in output['inputs']]

    @staticmethod
    def _flips(output):
        """Collect the ``flip`` meta value of every packed data sample."""
        return [
            sample.metainfo['flip'] for sample in output['data_samples']
        ]

    def test_exception(self):
        """Building from a flat list of transform dicts raises TypeError.

        ``transforms`` here is a list of dicts rather than a list of lists.
        """
        flat_cfg = dict(
            type='TestTimeAug',
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        with pytest.raises(TypeError):
            TRANSFORMS.build(flat_cfg)

    def test_multi_scale_flip_aug(self):
        """Check output shapes and flip flags for several TTA pipelines."""
        image = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        sample = {
            'img_id': '1',
            'img_path': 'data/color.jpg',
            'img': image,
            'ori_shape': image.shape,
            'ori_height': image.shape[0],
            'ori_width': image.shape[1],
            # Set initial values for default meta_keys
            'pad_shape': image.shape,
            'scale_factor': 1.0,
        }

        square_sizes = [(256, 256), (512, 512), (1024, 1024)]
        square_shapes = [(3, 256, 256), (3, 512, 512), (3, 1024, 1024)]
        ratios = [0.5, 1.0, 2.0]
        # Expected shapes for the fixture image at each ratio above.
        ratio_shapes = [(3, 144, 256), (3, 288, 512), (3, 576, 1024)]

        # Three absolute scales, no flip.
        output = self._apply([
            self._resize_options('scale', square_sizes),
            self._pack_options(),
        ], sample)
        assert self._shapes(output) == square_shapes

        # Three absolute scales, each without and with a horizontal flip.
        output = self._apply([
            self._resize_options('scale', square_sizes),
            self._flip_options(),
            self._pack_options(with_flip=True),
        ], sample)
        assert self._shapes(output) == [
            shape for shape in square_shapes for _ in range(2)
        ]
        assert self._flips(output) == [False, True] * 3

        # One absolute scale, no flip; only the first output is inspected.
        output = self._apply([
            self._resize_options('scale', [(512, 512)]),
            self._pack_options(),
        ], sample)
        assert [output['inputs'][0].shape] == [(3, 512, 512)]

        # One absolute scale, without and with a horizontal flip.
        output = self._apply([
            self._resize_options('scale', [(512, 512)]),
            self._flip_options(),
            self._pack_options(with_flip=True),
        ], sample)
        assert self._shapes(output) == [(3, 512, 512), (3, 512, 512)]
        assert self._flips(output) == [False, True]

        # Three relative scale factors, no flip.
        output = self._apply([
            self._resize_options('scale_factor', ratios),
            self._pack_options(),
        ], sample)
        assert self._shapes(output) == ratio_shapes

        # Three relative scale factors (keep_ratio=True), with flips.
        output = self._apply([
            self._resize_options('scale_factor', ratios, keep_ratio=True),
            self._flip_options(),
            self._pack_options(with_flip=True),
        ], sample)
        assert self._shapes(output) == [
            shape for shape in ratio_shapes for _ in range(2)
        ]
        assert self._flips(output) == [False, True] * 3