File size: 7,972 Bytes
6c9ac8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models import BaseDetector
from mmdet.models.utils import unpack_gt_instances
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType, OptConfigType


@torch.jit.script
def rescoring_mask(scores: torch.Tensor, mask_pred: torch.Tensor,
                   masks: torch.Tensor) -> torch.Tensor:
    """Re-weight instance scores by their mask quality ("maskness").

    For every instance, maskness is the mean of the soft mask probabilities
    taken over the pixels selected by the binary mask. A small epsilon keeps
    empty binary masks from dividing by zero; such instances get a score of 0.

    Args:
        scores: Per-instance scores, indexed along dim 0.
        mask_pred: Binarized masks; reduced over dims 1 and 2.
        masks: Soft mask probabilities with the same layout as ``mask_pred``.

    Returns:
        ``scores`` multiplied element-wise by the per-instance maskness.
    """
    binary = mask_pred.float()
    inside_sum = (masks * binary).sum([1, 2])
    area = binary.sum([1, 2]) + 1e-6
    return scores * (inside_sum / area)


@MODELS.register_module()
class SparseInst(BaseDetector):
    """Implementation of `SparseInst <https://arxiv.org/abs/2203.12827>`_
    (Sparse Instance Activation for Real-Time Instance Segmentation).

    The predictions contain instance masks, labels and scores. The ``bboxes``
    field is filled with all-zero placeholders only.

    Args:
        data_preprocessor (:obj:`ConfigDict` or dict): Config of
            :class:`DetDataPreprocessor` to process the input data.
        backbone (:obj:`ConfigDict` or dict): The backbone module.
        encoder (:obj:`ConfigDict` or dict): The encoder module.
        decoder (:obj:`ConfigDict` or dict): The decoder module.
        criterion (:obj:`ConfigDict` or dict, optional): The training matcher
            and losses. Only required by :meth:`loss`. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of SparseInst. It must provide ``score_thr`` and
            ``mask_thr_binary``. Only required by :meth:`predict`.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: ConfigType,
                 backbone: ConfigType,
                 encoder: ConfigType,
                 decoder: ConfigType,
                 criterion: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        # backbone
        self.backbone = MODELS.build(backbone)
        # encoder & decoder
        self.encoder = MODELS.build(encoder)
        self.decoder = MODELS.build(decoder)

        # matcher & loss (matcher is built in the loss module). Only needed
        # for training, so a missing config is reported lazily in `loss`.
        self.criterion = (
            MODELS.build(criterion) if criterion is not None else None)

        # inference thresholds. Item access works for both plain dicts and
        # ConfigDict. A missing config is reported lazily in `predict`.
        if test_cfg is not None:
            self.cls_threshold = test_cfg['score_thr']
            self.mask_threshold = test_cfg['mask_thr_binary']
        else:
            self.cls_threshold = None
            self.mask_threshold = None

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> dict:
        """Network forward process (backbone, encoder and decoder) without
        any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): Unused
                here. Defaults to None.

        Returns:
            dict: Raw decoder outputs. :meth:`predict` reads the keys
            ``'pred_logits'``, ``'pred_masks'`` and ``'pred_scores'``.
        """
        x = self.extract_feat(batch_inputs)
        return self.decoder(x)

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the masks to the original
                image size (``ori_shape``). When False, the masks have the
                un-padded input size (``img_shape``). Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample has a ``pred_instances`` field
            containing the following keys (also present, with zero
            instances, when nothing passes the score threshold).

                - scores (Tensor): Classification scores re-weighted by
                    mask quality, has a shape (num_instances, ).
                - labels (Tensor): Class labels, has a shape
                    (num_instances, ).
                - masks (Tensor): Boolean masks, has a shape
                    (num_instances, H, W).
                - bboxes (Tensor): All-zero placeholder with shape
                    (num_instances, 4), kept so that metrics that expect
                    boxes do not fail.

        Raises:
            RuntimeError: If the model was built without ``test_cfg``.
        """
        if self.cls_threshold is None or self.mask_threshold is None:
            raise RuntimeError(
                'SparseInst.predict() requires `test_cfg` with `score_thr` '
                'and `mask_thr_binary`.')

        # Padded batch size. Masks are first upsampled to this size and then
        # cropped to each image's un-padded `img_shape`.
        max_shape = batch_inputs.shape[-2:]
        output = self._forward(batch_inputs)

        pred_scores = output['pred_logits'].sigmoid()
        pred_masks = output['pred_masks'].sigmoid()
        pred_objectness = output['pred_scores'].sigmoid()
        # Geometric mean of class probability and objectness.
        pred_scores = torch.sqrt(pred_scores * pred_objectness)

        results_list = [
            self._postprocess_single(scores_per_image, mask_pred_per_image,
                                     datasample.metainfo, max_shape, rescale)
            for scores_per_image, mask_pred_per_image, datasample in zip(
                pred_scores, pred_masks, batch_data_samples)
        ]

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _postprocess_single(self, scores_per_image: Tensor,
                            mask_pred_per_image: Tensor, img_meta: dict,
                            max_shape: Tuple[int, ...],
                            rescale: bool) -> InstanceData:
        """Threshold, rescore and resize the predictions of one image."""
        h, w = img_meta['img_shape'][:2]
        if rescale:
            out_h, out_w = img_meta['ori_shape'][:2]
        else:
            out_h, out_w = h, w

        result = InstanceData()
        # Best class per query, then drop queries below the score threshold.
        scores, labels = scores_per_image.max(dim=-1)
        keep = scores > self.cls_threshold
        scores = scores[keep]
        labels = labels[keep]
        mask_pred_per_image = mask_pred_per_image[keep]

        if scores.size(0) == 0:
            # Keep the same fields as the non-empty case so downstream
            # metrics see a consistent InstanceData.
            result.masks = mask_pred_per_image.new_zeros((0, out_h, out_w),
                                                         dtype=torch.bool)
            result.scores = scores
            result.labels = labels
            result.bboxes = scores.new_zeros(0, 4)
            return result

        # Rescore using maskness: the mean soft probability inside each
        # binarized mask.
        scores = rescoring_mask(scores,
                                mask_pred_per_image > self.mask_threshold,
                                mask_pred_per_image)
        masks = F.interpolate(
            mask_pred_per_image.unsqueeze(1),
            size=max_shape,
            mode='bilinear',
            align_corners=False)[:, :, :h, :w]
        if rescale:
            masks = F.interpolate(
                masks,
                size=(out_h, out_w),
                mode='bilinear',
                align_corners=False)
        # Drop the channel dim in both branches -> (num_instances, H, W).
        masks = masks.squeeze(1)

        result.masks = masks > self.mask_threshold
        result.scores = scores
        result.labels = labels
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics.
        result.bboxes = scores.new_zeros(len(scores), 4)
        return result

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.

        Raises:
            RuntimeError: If the model was built without ``criterion``.
        """
        if self.criterion is None:
            raise RuntimeError(
                'SparseInst.loss() requires a `criterion` config.')
        outs = self._forward(batch_inputs)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = unpack_gt_instances(batch_data_samples)

        losses = self.criterion(outs, batch_gt_instances, batch_img_metas,
                                batch_gt_instances_ignore)
        return losses

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features with the backbone followed by the encoder.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
        x = self.backbone(batch_inputs)
        x = self.encoder(x)
        return x