File size: 5,581 Bytes
6c9ac8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import unittest

import torch
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from parameterized import parameterized

from mmdet.models.dense_heads import RepPointsHead
from mmdet.structures import DetDataSample


class TestRepPointsHead(unittest.TestCase):
    """Loss and prediction smoke tests for ``RepPointsHead``."""

    @staticmethod
    def _count_nonzero(losses):
        """Return how many entries of ``losses`` have a non-zero ``.item()``.

        Args:
            losses: Iterable of scalar loss tensors (one entry per element
                of the list stored under a loss key).

        Returns:
            int: Number of entries whose scalar value is not exactly zero.
        """
        return sum(1 for loss in losses if loss.item() != 0)

    @parameterized.expand(['moment', 'minmax', 'partial_minmax'])
    def test_head_loss(self, transform_method='moment'):
        """Check ``RepPointsHead`` losses with empty and single-box targets.

        The head is built from an inline config, run once on random feature
        maps, and its losses are checked for (a) an empty ground truth and
        (b) a single ground-truth box. Finally ``loss`` and ``predict`` are
        called to make sure they run end to end.

        Args:
            transform_method (str): Value forwarded to the head config's
                ``transform_method`` entry; supplied by
                ``parameterized.expand``.
        """
        cfg = ConfigDict(
            dict(
                num_classes=2,
                in_channels=32,
                point_feat_channels=10,
                num_points=9,
                gradient_mul=0.1,
                point_strides=[8, 16, 32, 64, 128],
                point_base_scale=4,
                loss_cls=dict(
                    type='FocalLoss',
                    use_sigmoid=True,
                    gamma=2.0,
                    alpha=0.25,
                    loss_weight=1.0),
                loss_bbox_init=dict(
                    type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
                loss_bbox_refine=dict(
                    type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                use_grid_points=False,
                center_init=True,
                transform_method=transform_method,
                moment_mul=0.01,
                init_cfg=dict(
                    type='Normal',
                    layer='Conv2d',
                    std=0.01,
                    override=dict(
                        type='Normal',
                        name='reppoints_cls_out',
                        std=0.01,
                        bias_prob=0.01)),
                train_cfg=dict(
                    init=dict(
                        assigner=dict(
                            type='PointAssigner', scale=4, pos_num=1),
                        allowed_border=-1,
                        pos_weight=-1,
                        debug=False),
                    refine=dict(
                        assigner=dict(
                            type='MaxIoUAssigner',
                            pos_iou_thr=0.5,
                            neg_iou_thr=0.4,
                            min_pos_iou=0,
                            ignore_iof_thr=-1),
                        allowed_border=-1,
                        pos_weight=-1,
                        debug=False)),
                test_cfg=dict(
                    nms_pre=1000,
                    min_bbox_size=0,
                    score_thr=0.05,
                    nms=dict(type='nms', iou_threshold=0.5),
                    max_per_img=100)))
        reppoints_head = RepPointsHead(**cfg)
        s = 256
        img_metas = [{
            'img_shape': (s, s),
            'scale_factor': (1, 1),
            'pad_shape': (s, s),
            'batch_input_shape': (s, s)
        }]
        # Five random feature maps with spatial sizes s/4, s/8, ..., s/64
        # and 32 channels (matching ``in_channels`` above).
        x = [
            torch.rand(1, 32, s // 2**(i + 2), s // 2**(i + 2))
            for i in range(5)
        ]

        # Test that empty ground truth encourages the network to
        # predict background
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.empty((0, 4))
        gt_instances.labels = torch.LongTensor([])
        gt_bboxes_ignore = None

        reppoints_head.train()
        # The same forward outputs are reused for both loss computations
        # below so that only the ground truth differs between them.
        forward_outputs = reppoints_head.forward(x)
        empty_gt_losses = reppoints_head.loss_by_feat(*forward_outputs,
                                                      [gt_instances],
                                                      img_metas,
                                                      gt_bboxes_ignore)
        # When there is no truth, the cls loss should be nonzero but there
        # should be no pts loss.
        for key, losses in empty_gt_losses.items():
            for loss in losses:
                if 'cls' in key:
                    self.assertGreater(loss.item(), 0,
                                       'cls loss should be non-zero')
                elif 'pts' in key:
                    self.assertEqual(
                        loss.item(), 0,
                        'there should be no reg loss when no ground true boxes'
                    )

        # When truth is non-empty then both cls and pts loss should be nonzero
        # for random inputs
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.Tensor(
            [[23.6667, 23.8757, 238.6326, 151.8874]])
        # NOTE(review): label 2 with num_classes=2 looks out of the
        # [0, num_classes) range -- presumably intentional; confirm.
        gt_instances.labels = torch.LongTensor([2])
        one_gt_losses = reppoints_head.loss_by_feat(*forward_outputs,
                                                    [gt_instances], img_metas,
                                                    gt_bboxes_ignore)
        # loss_cls should all be non-zero
        self.assertTrue(
            all(loss.item() > 0 for loss in one_gt_losses['loss_cls']))

        # only one level loss_pts_init is non-zero
        self.assertEqual(
            self._count_nonzero(one_gt_losses['loss_pts_init']), 1)

        # only one level loss_pts_refine is non-zero
        # (this check must read the refine-stage losses; reading
        # 'loss_pts_init' again would leave the refine stage unverified)
        self.assertEqual(
            self._count_nonzero(one_gt_losses['loss_pts_refine']), 1)

        # test loss
        samples = DetDataSample()
        samples.set_metainfo(img_metas[0])
        samples.gt_instances = gt_instances
        reppoints_head.loss(x, [samples])
        # test only predict
        reppoints_head.eval()
        reppoints_head.predict(x, [samples], rescale=True)