File size: 5,094 Bytes
6c9ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.structures import InstanceData
from mmdet.models.task_modules.assigners.assign_result import AssignResult
from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner
from mmdet.registry import TASK_UTILS
@TASK_UTILS.register_module()
class TransMaxIoUAssigner(MaxIoUAssigner):
    """``MaxIoUAssigner`` variant that matches on coordinate-swapped priors.

    Before any overlap is computed, columns 0/1 and 2/3 of every prior are
    exchanged (presumably ``(x1, y1, x2, y2)`` -> ``(y1, x1, y2, x2)`` --
    confirm the box layout against the caller). Everything else is delegated
    to the inherited ``iou_calculator`` and ``assign_wrt_overlaps``.
    """

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign ground-truth boxes to coordinate-swapped priors.

        The method proceeds in the following order:

        1. Optionally move all inputs to CPU when the number of gt boxes
           exceeds ``self.gpu_assign_thr`` (only if that threshold is > 0).
        2. Build ``trans_priors`` by reordering the first four columns of
           ``priors`` to ``[1, 0, 3, 2]`` and flattening to ``(-1, 4)``.
        3. Compute ``overlaps = self.iou_calculator(gt_bboxes, trans_priors)``.
        4. If ignore regions are supplied and ``self.ignore_iof_thr > 0``,
           set the overlap column of every swapped prior whose maximum IoF
           with the ignore boxes exceeds the threshold to ``-1``.
        5. Hand the overlap matrix and gt labels to the inherited
           ``assign_wrt_overlaps`` and, if step 1 moved data to CPU, move the
           resulting tensors back to the original device of ``priors``.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Only ``priors`` is read here; it can be
                anchors, points, or boxes predicted by a previous stage,
                with shape (n, 4).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` (shape (k, 4)) and ``labels``
                (shape (k, )) are read.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. Only its ``bboxes``
                attribute is read. Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result produced by
            ``assign_wrt_overlaps``.

        Example:
            >>> # the priors below are unchanged by the column swap
            >>> from mmengine.structures import InstanceData
            >>> self = TransMaxIoUAssigner(0.5, 0.5)
            >>> pred_instances = InstanceData()
            >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
            ...                                      [10, 10, 20, 20]])
            >>> gt_instances = InstanceData()
            >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 9]])
            >>> gt_instances.labels = torch.Tensor([0])
            >>> assign_result = self.assign(pred_instances, gt_instances)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        gt_labels = gt_instances.labels
        gt_bboxes_ignore = (
            gt_instances_ignore.bboxes
            if gt_instances_ignore is not None else None)

        # With many gt boxes the overlap matrix is built on CPU instead of
        # on the original device; a non-positive threshold disables this.
        gpu_thr = self.gpu_assign_thr
        assign_on_cpu = gpu_thr > 0 and gt_bboxes.shape[0] > gpu_thr
        if assign_on_cpu:
            device = priors.device
            priors = priors.cpu()
            gt_bboxes = gt_bboxes.cpu()
            gt_labels = gt_labels.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()

        # Exchange columns 0<->1 and 2<->3 of each prior; each column is
        # flattened to (-1, 1) so the result is always two-dimensional.
        swapped_columns = [
            priors[..., col].view(-1, 1) for col in (1, 0, 3, 2)
        ]
        trans_priors = torch.cat(swapped_columns, dim=-1)

        overlaps = self.iou_calculator(gt_bboxes, trans_priors)

        use_ignore = (
            self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
            and gt_bboxes_ignore.numel() > 0 and trans_priors.numel() > 0)
        if use_ignore:
            if self.ignore_wrt_candidates:
                # IoF is taken with the swapped priors as first argument.
                ignore_iof = self.iou_calculator(
                    trans_priors, gt_bboxes_ignore, mode='iof')
                max_ignore_iof = ignore_iof.max(dim=1)[0]
            else:
                # IoF is taken with the ignore boxes as first argument.
                ignore_iof = self.iou_calculator(
                    gt_bboxes_ignore, trans_priors, mode='iof')
                max_ignore_iof = ignore_iof.max(dim=0)[0]
            # Overwrite the whole overlap column of each such prior with -1.
            ignored_mask = max_ignore_iof > self.ignore_iof_thr
            overlaps[:, ignored_mask] = -1

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)

        if assign_on_cpu:
            # Return the result tensors to the device the priors came from.
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(
                device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result
|