# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from fvcore.nn import smooth_l1_loss

from detectron2.config import configurable
from detectron2.layers import ShapeSpec, cat, nonzero_tuple
from detectron2.structures import Boxes, Instances, pairwise_iou, pairwise_ioa
from detectron2.utils.events import get_event_storage
from detectron2.utils.memory import retry_if_cuda_oom

from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss
from detectron2.modeling.proposal_generator import RPN
from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY

@PROPOSAL_GENERATOR_REGISTRY.register()
class RPNWithIgnore(RPN):
    """
    Region Proposal Network that supports "ignore" ground-truth regions
    (``gt_classes < 0``): anchors overlapping them are excluded from the
    objectness loss, and anchor sampling can be weighted by matched IoU.
    """

    @configurable
    def __init__(
        self,
        *,
        ignore_thresh: float = 0.5,
        objectness_uncertainty: str = 'none',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.ignore_thresh = ignore_thresh
        self.objectness_uncertainty = objectness_uncertainty

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        ret = super().from_config(cfg, input_shape)
        ret["ignore_thresh"] = cfg.MODEL.RPN.IGNORE_THRESHOLD
        ret["objectness_uncertainty"] = cfg.MODEL.RPN.OBJECTNESS_UNCERTAINTY 
        return ret
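
    # Illustrative config snippet (values are assumptions for the sketch, not the
    # project's verified defaults) selecting this proposal generator and setting
    # the two keys read in from_config above:
    #
    #   MODEL:
    #     PROPOSAL_GENERATOR:
    #       NAME: "RPNWithIgnore"
    #     RPN:
    #       IGNORE_THRESHOLD: 0.5
    #       OBJECTNESS_UNCERTAINTY: "centerness"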
    
    @torch.jit.unused
    @torch.no_grad()
    def label_and_sample_anchors(
        self, anchors: List[Boxes], gt_instances: List[Instances]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        
        anchors = Boxes.cat(anchors)

        # separate valid and ignore gts
        gt_boxes_ign = [x.gt_boxes[x.gt_classes < 0] for x in gt_instances]
        gt_boxes = [x.gt_boxes[x.gt_classes >= 0] for x in gt_instances]

        del gt_instances

        gt_labels = []
        matched_gt_boxes = []

        for gt_boxes_i, gt_boxes_ign_i in zip(gt_boxes, gt_boxes_ign):
            """
            gt_boxes_i: ground-truth boxes for i-th image
            gt_boxes_ign_i: ground-truth ignore boxes for i-th image
            """

            match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
            matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
            
            # Matching is memory-expensive and may result in CPU tensors. But the result is small
            gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)

            gt_arange = torch.arange(match_quality_matrix.shape[1]).to(matched_idxs.device)
            matched_ious = match_quality_matrix[matched_idxs, gt_arange]
            
            best_ious_gt_vals, best_ious_gt_ind = match_quality_matrix.max(dim=1)

            del match_quality_matrix

            # anchors that are some GT's best match AND were matched as foreground
            fg_set = set((gt_labels_i == 1).nonzero().squeeze(1).tolist())
            best_inds = torch.tensor(list(set(best_ious_gt_ind.tolist()) & fg_set))

            # A vector of labels (-1, 0, 1) for each anchor
            # which denote (ignore, background, foreground)
            gt_labels_i = self._subsample_labels(gt_labels_i, matched_ious=matched_ious)

            # override the best-matching anchors for each GT so they are always selected
            # for sampling; otherwise aggressive thresholds may produce huge amounts of
            # low-quality FG.
            if best_inds.numel() > 0:
                gt_labels_i[best_inds] = 1.0

            if len(gt_boxes_i) == 0:
                # These values won't be used anyway since the anchor is labeled as background
                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
            else:
                # TODO wasted indexing computation for ignored boxes
                matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor

            if len(gt_boxes_ign_i) > 0: 

                # compute the quality matrix, only on the subset of background anchors
                background_inds = (gt_labels_i == 0).nonzero().squeeze(1)

                if background_inds.numel() > 0:
                    
                    match_quality_matrix_ign = retry_if_cuda_oom(pairwise_ioa)(gt_boxes_ign_i, anchors[background_inds])

                    # determine the boxes inside ignore regions with sufficient threshold
                    gt_labels_i[background_inds[match_quality_matrix_ign.max(0)[0] >= self.ignore_thresh]] = -1
                
                    del match_quality_matrix_ign

            gt_labels.append(gt_labels_i)  # N,AHW
            matched_gt_boxes.append(matched_gt_boxes_i)

        return gt_labels, matched_gt_boxes
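
    # Label convention produced above, per anchor: 1 = foreground, 0 = background,
    # -1 = ignore (dropped by subsampling, or overlapping an ignore region with
    # IoA >= ignore_thresh). Minimal usage sketch, assuming `rpn` is this module
    # and inputs follow detectron2 conventions:
    #
    #   gt_labels, matched_gt = rpn.label_and_sample_anchors(anchors, gt_instances)
    #   num_ignored = (gt_labels[0] == -1).sum()  # excluded from the objectness loss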
    
    def _subsample_labels(self, label, matched_ious=None):
        """
        Randomly sample a subset of positive and negative examples, and overwrite
        the label vector to the ignore value (-1) for all elements that are not
        included in the sample.
        Args:
            label (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.
            matched_ious (Tensor, optional): per-anchor IoU with its matched GT box;
                when given, sampling is weighted toward higher-IoU anchors.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0, matched_ious=matched_ious
        )
        # Fill with the ignore label (-1), then set positive and negative labels
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    @torch.jit.unused
    def losses(
        self,
        anchors: List[Boxes],
        pred_objectness_logits: List[torch.Tensor],
        gt_labels: List[torch.Tensor],
        pred_anchor_deltas: List[torch.Tensor],
        gt_boxes: List[torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Return the losses from a set of RPN predictions and their associated ground-truth.

        Args:
            anchors (list[Boxes or RotatedBoxes]): anchors for each feature map, each
                has shape (Hi*Wi*A, B), where B is box dimension (4 or 5).
            pred_objectness_logits (list[Tensor]): A list of L elements.
                Element i is a tensor of shape (N, Hi*Wi*A) representing
                the predicted objectness logits for all anchors.
            gt_labels (list[Tensor]): Output of :meth:`label_and_sample_anchors`.
            pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of shape
                (N, Hi*Wi*A, 4 or 5) representing the predicted "deltas" used to transform anchors
                to proposals.
            gt_boxes (list[Tensor]): Output of :meth:`label_and_sample_anchors`.

        Returns:
            dict[loss name -> loss value]: A dict mapping from loss name to loss value.
                Loss names are: `rpn/cls` for objectness classification and
                `rpn/loc` for proposal localization.
        """
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, sum(Hi*Wi*Ai))

        # Log the number of positive/negative anchors per-image that's used in training
        pos_mask = gt_labels == 1
        num_pos_anchors = pos_mask.sum().item()
        num_neg_anchors = (gt_labels == 0).sum().item()
        storage = get_event_storage()
        storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / num_images)
        storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / num_images)

        if self.objectness_uncertainty.lower() != 'none':
            localization_loss, objectness_loss = _dense_box_regression_loss_with_uncertainty(
                anchors,
                self.box2box_transform,
                pred_anchor_deltas,
                pred_objectness_logits,
                gt_boxes,
                pos_mask,
                box_reg_loss_type=self.box_reg_loss_type,
                smooth_l1_beta=self.smooth_l1_beta,
                uncertainty_type=self.objectness_uncertainty,
            )
        else:
            localization_loss = _dense_box_regression_loss(
                anchors,
                self.box2box_transform,
                pred_anchor_deltas,
                gt_boxes,
                pos_mask,
                box_reg_loss_type=self.box_reg_loss_type,
                smooth_l1_beta=self.smooth_l1_beta,
            )

            valid_mask = gt_labels >= 0
            objectness_loss = F.binary_cross_entropy_with_logits(
                cat(pred_objectness_logits, dim=1)[valid_mask],
                gt_labels[valid_mask].to(torch.float32),
                reduction="sum",
            )
        normalizer = self.batch_size_per_image * num_images
        losses = {
            "rpn/cls": objectness_loss / normalizer,
            "rpn/loc": localization_loss / normalizer,
        }
        losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
        return losses

def _dense_box_regression_loss_with_uncertainty(
    anchors: List[Union[Boxes, torch.Tensor]],
    box2box_transform: Box2BoxTransform,
    pred_anchor_deltas: List[torch.Tensor],
    pred_objectness_logits: List[torch.Tensor],
    gt_boxes: List[torch.Tensor],
    fg_mask: torch.Tensor,
    box_reg_loss_type="smooth_l1",
    smooth_l1_beta=0.0,
    uncertainty_type='centerness',
):
    """
    Compute loss for dense multi-level box regression.
    Loss is accumulated over ``fg_mask``.
    Args:
        anchors: #lvl anchor boxes, each is (HixWixA, 4)
        pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 4)
        gt_boxes: N ground truth boxes, each has shape (R, 4) (R = sum(Hi * Wi * A))
        fg_mask: the foreground boolean mask of shape (N, R) to compute loss on
        box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou",
            "diou", "ciou".
        smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to
            use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1"
    """
    if isinstance(anchors[0], Boxes):
        anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
    else:
        anchors = cat(anchors)

    n = len(gt_boxes)
    
    boxes_fg = Boxes(anchors.unsqueeze(0).repeat([n, 1, 1])[fg_mask])
    gt_boxes_fg = Boxes(torch.stack(gt_boxes)[fg_mask].detach())
    objectness_targets_anchors = matched_pairwise_iou(boxes_fg, gt_boxes_fg).detach()
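    # Instead of a hard 1.0 target, each foreground anchor's objectness is trained
    # toward its IoU with the matched GT box, e.g. an anchor overlapping its GT at
    # IoU 0.8 is pushed toward sigmoid(logit) = 0.8 (and weighted by 0.8 below).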

    objectness_logits = torch.cat(pred_objectness_logits, dim=1)

    # per-element BCE; numerically the same as -(y*torch.log(p) + (1 - y)*torch.log(1 - p))
    loss_box_conf = F.binary_cross_entropy_with_logits(
        objectness_logits[fg_mask], 
        objectness_targets_anchors,
        reduction='none'
    )

    loss_box_conf = (loss_box_conf * objectness_targets_anchors).sum()
    
    # keep track of how scores look for FG / BG.
    # ideally, FG scores grow well above BG scores as regression improves.
    storage = get_event_storage()
    storage.put_scalar("rpn/conf_pos_anchors", torch.sigmoid(objectness_logits[fg_mask]).mean().item())
    storage.put_scalar("rpn/conf_neg_anchors", torch.sigmoid(objectness_logits[~fg_mask]).mean().item())

    if box_reg_loss_type == "smooth_l1":
        gt_anchor_deltas = [box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)
        loss_box_reg = smooth_l1_loss(
            cat(pred_anchor_deltas, dim=1)[fg_mask],
            gt_anchor_deltas[fg_mask],
            beta=smooth_l1_beta,
            reduction="none",
        )
        
        loss_box_reg = (loss_box_reg.sum(dim=1) * objectness_targets_anchors).sum()

    else:
        raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'")

    return loss_box_reg, loss_box_conf

def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int, matched_ious=None, eps=1e-4
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.
    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of labels with value >= 0 to return.
            Values that are not sampled will be filled with -1 (ignore).
        positive_fraction (float): The number of subsampled labels with values > 0
            is `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible.
        bg_label (int): label index of background ("negative") class.
        matched_ious (Tensor, optional): per-anchor IoU with its matched GT box; when
            provided, indices are drawn with torch.multinomial weighted by these IoUs.
        eps (float): small constant added to the multinomial weights for stability.
    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    num_pos = int(num_samples * positive_fraction)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    
    # randomly select positive and negative examples
    if num_pos > 0 and matched_ious is not None:
        perm1 = torch.multinomial(matched_ious[positive] + eps, num_pos)
    else:
        perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    if num_neg > 0 and matched_ious is not None:
        perm2 = torch.multinomial(matched_ious[negative] + eps, num_neg)
    else:
        perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx
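
# Minimal sketch of the IoU-weighted sampling path (values are illustrative):
#
#   labels = torch.tensor([1, 1, 0, 0, -1])
#   ious = torch.tensor([0.90, 0.30, 0.10, 0.05, 0.00])
#   pos_idx, neg_idx = subsample_labels(
#       labels, num_samples=2, positive_fraction=0.5, bg_label=0, matched_ious=ious
#   )
#   # with matched_ious given, torch.multinomial favors higher-IoU anchors,
#   # so index 0 is more likely to be kept than index 1.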

def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Compute pairwise intersection over union (IOU) of two sets of matched
    boxes that have the same number of boxes.
    Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.
    Args:
        boxes1 (Boxes): bounding boxes, sized [N,4].
        boxes2 (Boxes): same length as boxes1
    Returns:
        Tensor: iou, sized [N].
    """
    assert len(boxes1) == len(
        boxes2
    ), "boxlists should have the same number of entries, got {}, {}".format(
        len(boxes1), len(boxes2)
    )
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [N]
    box1, box2 = boxes1.tensor, boxes2.tensor
    lt = torch.max(box1[:, :2], box2[:, :2])  # [N,2]
    rb = torch.min(box1[:, 2:], box2[:, 2:])  # [N,2]
    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]
    iou = inter / (area1 + area2 - inter)  # [N]
    return iou
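
# Worked example (illustrative): identical boxes give IoU 1.0; a 2x2 box shifted
# by half its width overlaps in a 1x2 strip, so IoU = 2 / (4 + 4 - 2) = 1/3:
#
#   b1 = Boxes(torch.tensor([[0.0, 0.0, 2.0, 2.0], [0.0, 0.0, 2.0, 2.0]]))
#   b2 = Boxes(torch.tensor([[0.0, 0.0, 2.0, 2.0], [1.0, 0.0, 3.0, 2.0]]))
#   matched_pairwise_iou(b1, b2)  # tensor([1.0000, 0.3333])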