# File size: 6,181 Bytes | 6c9ac8f | original listing lines 1-152
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.structures import InstanceData, PixelData
from mmdet.structures import DetDataSample
def _equal(a, b):
if isinstance(a, (torch.Tensor, np.ndarray)):
return (a == b).all()
else:
return a == b
class TestDetDataSample(TestCase):
    """Unit tests for ``DetDataSample``: construction, setters, deleters."""

    @staticmethod
    def _assert_fields_equal(element, expected):
        """Assert ``element`` stores each key/value of ``expected``.

        Args:
            element: Data element whose attributes are read.
            expected (dict): Mapping of attribute name to expected value.
        """
        for key, value in expected.items():
            assert _equal(getattr(element, key), value), key

    def test_init(self):
        """Metainfo passed to the constructor is readable from the sample."""
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))

        det_data_sample = DetDataSample(metainfo=meta_info)
        assert 'img_size' in det_data_sample
        assert det_data_sample.img_size == [256, 256]
        assert det_data_sample.get('img_size') == [256, 256]
        # Also check the array/tensor entries that were passed in.
        assert _equal(det_data_sample.scale_factor,
                      meta_info['scale_factor'])
        assert _equal(det_data_sample.img_shape, meta_info['img_shape'])

    def test_setter(self):
        """Fields accept matching data elements and reject raw tensors."""
        det_data_sample = DetDataSample()

        # test gt_instances
        gt_instances_data = dict(
            bboxes=torch.rand(4, 4),
            labels=torch.rand(4),
            masks=np.random.rand(4, 2, 2))
        det_data_sample.gt_instances = InstanceData(**gt_instances_data)
        assert 'gt_instances' in det_data_sample
        self._assert_fields_equal(det_data_sample.gt_instances,
                                  gt_instances_data)

        # test pred_instances
        pred_instances_data = dict(
            bboxes=torch.rand(2, 4),
            labels=torch.rand(2),
            masks=np.random.rand(2, 2, 2))
        det_data_sample.pred_instances = InstanceData(**pred_instances_data)
        assert 'pred_instances' in det_data_sample
        self._assert_fields_equal(det_data_sample.pred_instances,
                                  pred_instances_data)

        # test proposals
        proposals_data = dict(bboxes=torch.rand(4, 4), labels=torch.rand(4))
        det_data_sample.proposals = InstanceData(**proposals_data)
        assert 'proposals' in det_data_sample
        self._assert_fields_equal(det_data_sample.proposals, proposals_data)

        # test ignored_instances
        ignored_instances_data = dict(
            bboxes=torch.rand(4, 4), labels=torch.rand(4))
        det_data_sample.ignored_instances = InstanceData(
            **ignored_instances_data)
        assert 'ignored_instances' in det_data_sample
        self._assert_fields_equal(det_data_sample.ignored_instances,
                                  ignored_instances_data)

        # test gt_panoptic_seg
        gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
        det_data_sample.gt_panoptic_seg = PixelData(**gt_panoptic_seg_data)
        assert 'gt_panoptic_seg' in det_data_sample
        self._assert_fields_equal(det_data_sample.gt_panoptic_seg,
                                  gt_panoptic_seg_data)

        # test pred_panoptic_seg
        pred_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
        det_data_sample.pred_panoptic_seg = PixelData(
            **pred_panoptic_seg_data)
        assert 'pred_panoptic_seg' in det_data_sample
        self._assert_fields_equal(det_data_sample.pred_panoptic_seg,
                                  pred_panoptic_seg_data)

        # test gt_sem_seg
        # Uses the same ``*_sem_seg`` naming as the type-error check below.
        gt_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
        det_data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
        assert 'gt_sem_seg' in det_data_sample
        self._assert_fields_equal(det_data_sample.gt_sem_seg,
                                  gt_sem_seg_data)

        # test pred_sem_seg
        pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
        det_data_sample.pred_sem_seg = PixelData(**pred_sem_seg_data)
        assert 'pred_sem_seg' in det_data_sample
        self._assert_fields_equal(det_data_sample.pred_sem_seg,
                                  pred_sem_seg_data)

        # test type error: assigning a raw tensor to these fields is
        # expected to raise an AssertionError.
        with pytest.raises(AssertionError):
            det_data_sample.pred_instances = torch.rand(2, 4)

        with pytest.raises(AssertionError):
            det_data_sample.pred_panoptic_seg = torch.rand(2, 4)

        with pytest.raises(AssertionError):
            det_data_sample.pred_sem_seg = torch.rand(2, 4)

    def test_deleter(self):
        """Deleting a field removes it from the sample."""
        det_data_sample = DetDataSample()

        # Elements are built with ``**kwargs``, as in ``test_setter``.
        gt_instances_data = dict(
            bboxes=torch.rand(4, 4),
            labels=torch.rand(4),
            masks=np.random.rand(4, 2, 2))
        det_data_sample.gt_instances = InstanceData(**gt_instances_data)
        assert 'gt_instances' in det_data_sample
        del det_data_sample.gt_instances
        assert 'gt_instances' not in det_data_sample

        pred_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
        det_data_sample.pred_panoptic_seg = PixelData(
            **pred_panoptic_seg_data)
        assert 'pred_panoptic_seg' in det_data_sample
        del det_data_sample.pred_panoptic_seg
        assert 'pred_panoptic_seg' not in det_data_sample

        pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
        det_data_sample.pred_sem_seg = PixelData(**pred_sem_seg_data)
        assert 'pred_sem_seg' in det_data_sample
        del det_data_sample.pred_sem_seg
        assert 'pred_sem_seg' not in det_data_sample
# End of listing.