|
from unittest import TestCase |
|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
from mmengine.structures import InstanceData, PixelData |
|
|
|
from mmdet.structures import DetDataSample |
|
|
|
|
|
def _equal(a, b): |
|
if isinstance(a, (torch.Tensor, np.ndarray)): |
|
return (a == b).all() |
|
else: |
|
return a == b |
|
|
|
|
|
class TestDetDataSample(TestCase): |
|
|
|
def test_init(self): |
|
meta_info = dict( |
|
img_size=[256, 256], |
|
scale_factor=np.array([1.5, 1.5]), |
|
img_shape=torch.rand(4)) |
|
|
|
det_data_sample = DetDataSample(metainfo=meta_info) |
|
assert 'img_size' in det_data_sample |
|
assert det_data_sample.img_size == [256, 256] |
|
assert det_data_sample.get('img_size') == [256, 256] |
|
|
|
def test_setter(self): |
|
det_data_sample = DetDataSample() |
|
|
|
gt_instances_data = dict( |
|
bboxes=torch.rand(4, 4), |
|
labels=torch.rand(4), |
|
masks=np.random.rand(4, 2, 2)) |
|
gt_instances = InstanceData(**gt_instances_data) |
|
det_data_sample.gt_instances = gt_instances |
|
assert 'gt_instances' in det_data_sample |
|
assert _equal(det_data_sample.gt_instances.bboxes, |
|
gt_instances_data['bboxes']) |
|
assert _equal(det_data_sample.gt_instances.labels, |
|
gt_instances_data['labels']) |
|
assert _equal(det_data_sample.gt_instances.masks, |
|
gt_instances_data['masks']) |
|
|
|
|
|
pred_instances_data = dict( |
|
bboxes=torch.rand(2, 4), |
|
labels=torch.rand(2), |
|
masks=np.random.rand(2, 2, 2)) |
|
pred_instances = InstanceData(**pred_instances_data) |
|
det_data_sample.pred_instances = pred_instances |
|
assert 'pred_instances' in det_data_sample |
|
assert _equal(det_data_sample.pred_instances.bboxes, |
|
pred_instances_data['bboxes']) |
|
assert _equal(det_data_sample.pred_instances.labels, |
|
pred_instances_data['labels']) |
|
assert _equal(det_data_sample.pred_instances.masks, |
|
pred_instances_data['masks']) |
|
|
|
|
|
proposals_data = dict(bboxes=torch.rand(4, 4), labels=torch.rand(4)) |
|
proposals = InstanceData(**proposals_data) |
|
det_data_sample.proposals = proposals |
|
assert 'proposals' in det_data_sample |
|
assert _equal(det_data_sample.proposals.bboxes, |
|
proposals_data['bboxes']) |
|
assert _equal(det_data_sample.proposals.labels, |
|
proposals_data['labels']) |
|
|
|
|
|
ignored_instances_data = dict( |
|
bboxes=torch.rand(4, 4), labels=torch.rand(4)) |
|
ignored_instances = InstanceData(**ignored_instances_data) |
|
det_data_sample.ignored_instances = ignored_instances |
|
assert 'ignored_instances' in det_data_sample |
|
assert _equal(det_data_sample.ignored_instances.bboxes, |
|
ignored_instances_data['bboxes']) |
|
assert _equal(det_data_sample.ignored_instances.labels, |
|
ignored_instances_data['labels']) |
|
|
|
|
|
gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4)) |
|
gt_panoptic_seg = PixelData(**gt_panoptic_seg_data) |
|
det_data_sample.gt_panoptic_seg = gt_panoptic_seg |
|
assert 'gt_panoptic_seg' in det_data_sample |
|
assert _equal(det_data_sample.gt_panoptic_seg.panoptic_seg, |
|
gt_panoptic_seg_data['panoptic_seg']) |
|
|
|
|
|
pred_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4)) |
|
pred_panoptic_seg = PixelData(**pred_panoptic_seg_data) |
|
det_data_sample.pred_panoptic_seg = pred_panoptic_seg |
|
assert 'pred_panoptic_seg' in det_data_sample |
|
assert _equal(det_data_sample.pred_panoptic_seg.panoptic_seg, |
|
pred_panoptic_seg_data['panoptic_seg']) |
|
|
|
|
|
gt_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2)) |
|
gt_segm_seg = PixelData(**gt_segm_seg_data) |
|
det_data_sample.gt_segm_seg = gt_segm_seg |
|
assert 'gt_segm_seg' in det_data_sample |
|
assert _equal(det_data_sample.gt_segm_seg.segm_seg, |
|
gt_segm_seg_data['segm_seg']) |
|
|
|
|
|
pred_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2)) |
|
pred_segm_seg = PixelData(**pred_segm_seg_data) |
|
det_data_sample.pred_segm_seg = pred_segm_seg |
|
assert 'pred_segm_seg' in det_data_sample |
|
assert _equal(det_data_sample.pred_segm_seg.segm_seg, |
|
pred_segm_seg_data['segm_seg']) |
|
|
|
|
|
with pytest.raises(AssertionError): |
|
det_data_sample.pred_instances = torch.rand(2, 4) |
|
|
|
with pytest.raises(AssertionError): |
|
det_data_sample.pred_panoptic_seg = torch.rand(2, 4) |
|
|
|
with pytest.raises(AssertionError): |
|
det_data_sample.pred_sem_seg = torch.rand(2, 4) |
|
|
|
def test_deleter(self): |
|
gt_instances_data = dict( |
|
bboxes=torch.rand(4, 4), |
|
labels=torch.rand(4), |
|
masks=np.random.rand(4, 2, 2)) |
|
|
|
det_data_sample = DetDataSample() |
|
gt_instances = InstanceData(data=gt_instances_data) |
|
det_data_sample.gt_instances = gt_instances |
|
assert 'gt_instances' in det_data_sample |
|
del det_data_sample.gt_instances |
|
assert 'gt_instances' not in det_data_sample |
|
|
|
pred_panoptic_seg_data = torch.rand(5, 4) |
|
pred_panoptic_seg = PixelData(data=pred_panoptic_seg_data) |
|
det_data_sample.pred_panoptic_seg = pred_panoptic_seg |
|
assert 'pred_panoptic_seg' in det_data_sample |
|
del det_data_sample.pred_panoptic_seg |
|
assert 'pred_panoptic_seg' not in det_data_sample |
|
|
|
pred_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2)) |
|
pred_segm_seg = PixelData(**pred_segm_seg_data) |
|
det_data_sample.pred_segm_seg = pred_segm_seg |
|
assert 'pred_segm_seg' in det_data_sample |
|
del det_data_sample.pred_segm_seg |
|
assert 'pred_segm_seg' not in det_data_sample |
|
|