File size: 9,212 Bytes
6c9ac8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import reduce_mean


def compute_mask_iou(inputs, targets):
    """Return the IoU between binarized predicted masks and target masks.

    Args:
        inputs: Mask logits; a sigmoid is applied and the result is
            binarized at ``>= 0.4``.
        targets: Target masks, binarized at ``> 0.5``. Must be
            broadcast-compatible with ``inputs`` (they are multiplied
            element-wise).

    Returns:
        Tensor of IoU values reduced over the last dimension. A small
        epsilon in the denominator yields 0 when both masks are empty.
    """
    # Hard masks on both sides so the IoU is computed on binary regions.
    pred_bin = inputs.sigmoid().ge(0.4).float()
    gt_bin = targets.gt(0.5).float()

    overlap = (pred_bin * gt_bin).sum(-1)
    total = gt_bin.sum(-1) + pred_bin.sum(-1) - overlap
    return overlap / (total + 1e-6)


def dice_score(inputs, targets):
    """Return the pairwise dice score between predictions and targets.

    Args:
        inputs: Mask logits (a sigmoid is applied here); presumably shaped
            ``(num_preds, num_pixels)``.
        targets: Target masks as a 2-D tensor whose last dimension matches
            that of ``inputs`` (it is transposed and matrix-multiplied).

    Returns:
        Tensor holding the score of every prediction against every target,
        i.e. one row per prediction and one column per target.
    """
    probs = inputs.sigmoid()
    # Squared-magnitude terms: one column entry per prediction, one row
    # entry per target, broadcast to the full pairwise grid below.
    pred_energy = (probs * probs).sum(-1).unsqueeze(1)
    target_energy = (targets * targets).sum(-1)
    overlap = torch.matmul(probs, targets.t())
    return 2 * overlap / (pred_energy + target_energy + 1e-4)


@MODELS.register_module()
class SparseInstCriterion(nn.Module):
    """Matching-based criterion producing four SparseInst training losses.

    Predictions are matched to ground-truth instances by the configured
    assigner; from that matching the classification, objectness (predicted
    mask IoU), dice and pixel-wise mask losses are computed.

    This part is partially derived from:

    https://github.com/facebookresearch/detr/blob/main/models/detr.py.

    Args:
        num_classes: Number of classes. The value itself is also used as the
            label assigned to every unmatched prediction.
        assigner: Config passed to ``TASK_UTILS.build`` to create the matcher.
        loss_cls: Config of the classification loss (built via ``MODELS``).
        loss_obj: Config of the objectness / IoU-score loss.
        loss_mask: Config of the pixel-wise mask loss.
        loss_dice: Config of the dice loss.
    """

    def __init__(
        self,
        num_classes,
        assigner,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=1.0),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            reduction='sum',
            eps=5e-5,
            loss_weight=2.0),
    ):
        super().__init__()
        self.matcher = TASK_UTILS.build(assigner)
        self.num_classes = num_classes
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_obj = MODELS.build(loss_obj)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image matches into ``(batch_idx, src_idx)``.

        Args:
            indices: One ``(src, tgt)`` pair of index tensors per image.

        Returns:
            Tuple of two 1-D tensors: the image index of every matched pair
            and the matched prediction index inside that image.
        """
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-image matches into ``(batch_idx, tgt_idx)``.

        Args:
            indices: One ``(src, tgt)`` pair of index tensors per image.

        Returns:
            Tuple of two 1-D tensors: the image index of every matched pair
            and the matched ground-truth index inside that image.
        """
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def loss_classification(self, outputs, batch_gt_instances, indices,
                            num_instances):
        """Compute the classification loss over all predictions.

        Args:
            outputs: Dict of predictions; must contain ``'pred_logits'``
                whose first two dimensions index (image, prediction).
            batch_gt_instances: Per-image ground truth exposing ``labels``.
            indices: Per-image ``(src, tgt)`` matching index tensors.
            num_instances: Normalizer the loss is divided by.

        Returns:
            The classification loss divided by ``num_instances``.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [gt.labels[J] for gt, (_, J) in zip(batch_gt_instances, indices)])
        # Every prediction starts with label ``num_classes``; matched ones
        # are then overwritten with the label of their ground truth.
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        target_classes[idx] = target_classes_o

        src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)
        # comp focal loss.
        class_loss = self.loss_cls(
            src_logits,
            target_classes,
        ) / num_instances
        return class_loss

    def loss_masks_with_iou_objectness(self, outputs, batch_gt_instances,
                                       indices, num_instances):
        """Compute the objectness, dice and pixel-wise mask losses.

        Args:
            outputs: Dict of predictions; must contain ``'pred_masks'`` and
                ``'pred_scores'``.
            batch_gt_instances: Per-image ground truth exposing ``masks``
                (with ``to_tensor`` and ``len``).
            indices: Per-image ``(src, tgt)`` matching index tensors.
            num_instances: Normalizer applied to the dice loss.

        Returns:
            Tuple ``(loss_objectness, loss_dice, loss_mask)``. When the batch
            holds no ground-truth mask all three are zeros that remain
            connected to the predictions.
        """
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        # Bx100xHxW
        assert 'pred_masks' in outputs
        assert 'pred_scores' in outputs
        src_iou_scores = outputs['pred_scores']
        src_masks = outputs['pred_masks']
        with torch.no_grad():
            target_masks = torch.cat([
                gt.masks.to_tensor(
                    dtype=src_masks.dtype, device=src_masks.device)
                for gt in batch_gt_instances
            ])
        num_masks = [len(gt.masks) for gt in batch_gt_instances]
        target_masks = target_masks.to(src_masks)
        if len(target_masks) == 0:
            # Nothing to supervise: return zero losses built from the
            # predictions so they are still part of the autograd graph.
            loss_dice = src_masks.sum() * 0.0
            loss_mask = src_masks.sum() * 0.0
            loss_objectness = src_iou_scores.sum() * 0.0

            return loss_objectness, loss_dice, loss_mask

        src_masks = src_masks[src_idx]
        # Bring the ground-truth masks to the resolution of the predictions.
        target_masks = F.interpolate(
            target_masks[:, None],
            size=src_masks.shape[-2:],
            mode='bilinear',
            align_corners=False).squeeze(1)

        src_masks = src_masks.flatten(1)
        # ``target_masks`` concatenates the masks of all images, so a
        # per-image target index must be shifted by the number of masks in
        # the preceding images. The shift is looked up through the batch
        # index of each matched pair, which stays correct even if an image
        # has fewer matches than ground-truth masks.
        mask_counts = torch.as_tensor(
            num_masks, dtype=torch.int64, device=tgt_idx[1].device)
        image_offsets = mask_counts.cumsum(0) - mask_counts
        mix_tgt_idx = image_offsets[tgt_idx[0]] + tgt_idx[1]

        target_masks = target_masks[mix_tgt_idx].flatten(1)

        # The IoU of each matched pair is the (gradient-free) target of the
        # objectness branch.
        with torch.no_grad():
            ious = compute_mask_iou(src_masks, target_masks)

        tgt_iou_scores = ious
        src_iou_scores = src_iou_scores[src_idx]
        tgt_iou_scores = tgt_iou_scores.flatten(0)
        src_iou_scores = src_iou_scores.flatten(0)

        loss_objectness = self.loss_obj(src_iou_scores, tgt_iou_scores)
        loss_dice = self.loss_dice(src_masks, target_masks) / num_instances
        loss_mask = self.loss_mask(src_masks, target_masks)

        return loss_objectness, loss_dice, loss_mask

    def forward(self, outputs, batch_gt_instances, batch_img_metas,
                batch_gt_instances_ignore):
        """Match predictions to ground truth and compute all losses.

        Args:
            outputs: Dict of prediction tensors (``'pred_logits'``,
                ``'pred_masks'``, ``'pred_scores'``).
            batch_gt_instances: Per-image ground truth exposing ``labels``
                and ``masks``.
            batch_img_metas: Not used by this criterion.
            batch_gt_instances_ignore: Not used by this criterion.

        Returns:
            Dict with keys ``loss_cls``, ``loss_obj``, ``loss_dice`` and
            ``loss_mask``.
        """
        # Retrieve the matching between the outputs of
        # the last layer and the targets
        indices = self.matcher(outputs, batch_gt_instances)
        # Index tensors must be integral; the matcher can hand back empty
        # floating-point tensors for a batch without ground truth, which
        # would be rejected when used for indexing. ``long()`` is a no-op
        # for tensors that are already int64.
        indices = [(src.long(), tgt.long()) for src, tgt in indices]
        # Compute the average number of target instances
        # across all nodes, for normalization purposes
        num_instances = sum(gt.labels.shape[0] for gt in batch_gt_instances)
        num_instances = torch.as_tensor([num_instances],
                                        dtype=torch.float,
                                        device=next(iter(
                                            outputs.values())).device)
        # Clamp to at least 1 so the normalization never divides by zero.
        num_instances = reduce_mean(num_instances).clamp_(min=1).item()
        # Compute all the requested losses
        loss_cls = self.loss_classification(outputs, batch_gt_instances,
                                            indices, num_instances)
        loss_obj, loss_dice, loss_mask = self.loss_masks_with_iou_objectness(
            outputs, batch_gt_instances, indices, num_instances)

        return dict(
            loss_cls=loss_cls,
            loss_obj=loss_obj,
            loss_dice=loss_dice,
            loss_mask=loss_mask)


@TASK_UTILS.register_module()
class SparseInstMatcher(nn.Module):
    """One-to-one matcher between predictions and ground-truth instances.

    The matching score of a (prediction, ground truth) pair is
    ``dice ** alpha * class_prob ** beta`` and the assignment maximizing the
    total score is solved per image with the Hungarian algorithm.

    Args:
        alpha: Exponent applied to the mask dice score.
        beta: Exponent applied to the predicted probability of the
            ground-truth class.
    """

    def __init__(self, alpha=0.8, beta=0.2):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mask_score = dice_score

    def forward(self, outputs, batch_gt_instances):
        """Compute the per-image assignment.

        Args:
            outputs: Dict with ``'pred_masks'`` of shape ``(B, N, H, W)`` and
                ``'pred_logits'`` holding ``B * N`` rows of class logits.
            batch_gt_instances: Per-image ground truth exposing ``labels``
                and ``masks`` (with ``to_tensor`` and ``len``).

        Returns:
            List with one ``(pred_indices, gt_indices)`` tuple of int64
            tensors per image. All tuples are empty when the batch contains
            no ground-truth instance.
        """
        with torch.no_grad():
            pred_masks = outputs['pred_masks']
            # Unpacking also enforces that the masks are 4-D.
            B, N, _, _ = pred_masks.shape
            pred_logits = outputs['pred_logits'].sigmoid()
            device = pred_masks.device

            tgt_ids = torch.cat([gt.labels for gt in batch_gt_instances])

            if tgt_ids.shape[0] == 0:
                # Nothing to match. The empty results must be integer
                # tensors: callers use them for indexing, which rejects
                # floating-point index tensors even when they are empty.
                return [(torch.empty(0, dtype=torch.int64, device=device),
                         torch.empty(0, dtype=torch.int64, device=device))
                        for _ in range(B)]
            tgt_masks = torch.cat([
                gt.masks.to_tensor(dtype=pred_masks.dtype, device=device)
                for gt in batch_gt_instances
            ])

            # Bring the ground-truth masks to the prediction resolution.
            tgt_masks = F.interpolate(
                tgt_masks[:, None],
                size=pred_masks.shape[-2:],
                mode='bilinear',
                align_corners=False).squeeze(1)

            # ``reshape`` also accepts non-contiguous predictions.
            pred_masks = pred_masks.reshape(B * N, -1)
            tgt_masks = tgt_masks.flatten(1)
            # Score computation is done in full precision regardless of any
            # surrounding autocast context.
            with autocast(enabled=False):
                pred_masks = pred_masks.float()
                tgt_masks = tgt_masks.float()
                pred_logits = pred_logits.float()
                mask_score = self.mask_score(pred_masks, tgt_masks)
                # (B * N) x (number of gts in the whole batch)
                matching_prob = pred_logits.reshape(B * N, -1)[:, tgt_ids]
                C = (mask_score**self.alpha) * (matching_prob**self.beta)

            C = C.view(B, N, -1).cpu()
            # hungarian matching
            sizes = [len(gt.masks) for gt in batch_gt_instances]
            # Splitting along the ground-truth axis yields one chunk per
            # image; ``c[i]`` keeps only that image's own predictions.
            indices = [
                linear_sum_assignment(c[i], maximize=True)
                for i, c in enumerate(C.split(sizes, -1))
            ]
            indices = [(torch.as_tensor(i, dtype=torch.int64),
                        torch.as_tensor(j, dtype=torch.int64))
                       for i, j in indices]
            return indices