|
|
|
from unittest import TestCase |
|
|
|
import torch |
|
from mmengine import ConfigDict |
|
|
|
from mmdet.models import DetTTAModel |
|
from mmdet.registry import MODELS |
|
from mmdet.structures import DetDataSample |
|
from mmdet.testing import get_detector_cfg |
|
from mmdet.utils import register_all_modules |
|
|
|
|
|
class TestDetTTAModel(TestCase):
    """Tests for merging test-time-augmented predictions in DetTTAModel."""

    def setUp(self):
        # Register mmdet modules so the registry lookup in ``MODELS.build``
        # below can resolve type names such as 'DetTTAModel'.
        register_all_modules()

    def test_det_tta_model(self):
        """Run ``test_step`` over several augmented views and check the merge.

        A RetinaNet detector is wrapped in ``DetTTAModel`` and fed
        ``num_augs`` augmented views of one 100x100 image (different scales
        and flip settings). The views must be merged into a single
        prediction that respects ``tta_cfg.max_per_img``.
        """
        max_per_img = 100
        num_augs = 12

        detector_cfg = get_detector_cfg(
            'retinanet/retinanet_r18_fpn_1x_coco.py')
        cfg = ConfigDict(
            type='DetTTAModel',
            module=detector_cfg,
            tta_cfg=dict(
                nms=dict(type='nms', iou_threshold=0.5),
                max_per_img=max_per_img))

        model: DetTTAModel = MODELS.build(cfg)

        imgs = []
        data_samples = []
        directions = ['horizontal', 'vertical']
        for i in range(num_augs):
            # Each view is a rescaled copy of the same 100x100 original image.
            size = 100 + 10 * i
            scale = size / 100
            # Vary the flip flag and direction across views; both
            # 'horizontal' and 'vertical' occur together with flip=True.
            flip_direction = directions[0] if i % 3 == 0 else directions[1]
            imgs.append(torch.randn(1, 3, size, size))
            data_samples.append([
                DetDataSample(
                    metainfo=dict(
                        ori_shape=(100, 100),
                        img_shape=(size, size),
                        scale_factor=(scale, scale),
                        flip=(i % 2 == 0),
                        flip_direction=flip_direction))
            ])

        # Inference only: no autograd graph is needed for the forward passes.
        with torch.no_grad():
            results = model.test_step(
                dict(inputs=imgs, data_samples=data_samples))

        # Every view carries a single data sample (batch size 1), so all
        # views must collapse into exactly one merged prediction whose
        # instance count is capped by ``max_per_img``.
        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], DetDataSample)
        self.assertLessEqual(len(results[0].pred_instances), max_per_img)
|
|