|
|
|
from typing import Optional |
|
|
|
import torch |
|
from mmengine.structures import InstanceData |
|
|
|
from mmdet.models.task_modules.assigners.assign_result import AssignResult |
|
from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner |
|
from mmdet.registry import TASK_UTILS |
|
|
|
|
|
@TASK_UTILS.register_module()
class TransMaxIoUAssigner(MaxIoUAssigner):
    """Max-IoU assigner that matches ground truths against swapped priors.

    Before overlaps are computed, the first four columns of every prior are
    reordered as ``(1, 0, 3, 2)``. Assuming the usual ``(x1, y1, x2, y2)``
    box layout this exchanges the x and y coordinates of each prior
    (NOTE(review): confirm the layout against the caller).

    The thresholds, the IoU calculator and the matching rule
    (``assign_wrt_overlaps``) used by :meth:`assign` are not defined here;
    they come from :class:`MaxIoUAssigner`.
    """

    @staticmethod
    def _swap_xy(priors: torch.Tensor) -> torch.Tensor:
        """Return ``priors`` with columns reordered to ``(1, 0, 3, 2)``.

        Args:
            priors (Tensor): Boxes whose last dimension holds at least four
                values.

        Returns:
            Tensor: The reordered boxes, flattened to shape ``(-1, 4)``.
        """
        swapped = torch.stack(
            [priors[..., col] for col in (1, 0, 3, 2)], dim=-1)
        # ``reshape`` (rather than ``view``) keeps this working for priors
        # whose memory layout is not view-compatible.
        return swapped.reshape(-1, 4)

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt boxes to the column-swapped priors.

        The method performs the following steps:

        1. Read ``priors`` from ``pred_instances`` and ``bboxes`` /
           ``labels`` from ``gt_instances`` (plus the ignore ``bboxes``
           when ``gt_instances_ignore`` is given).
        2. If ``self.gpu_assign_thr > 0`` and there are more gt boxes than
           that threshold, move all inputs to the CPU for the computation.
        3. Swap the prior columns to ``(1, 0, 3, 2)`` and compute
           ``self.iou_calculator(gt_bboxes, swapped_priors)``.
        4. If ``self.ignore_iof_thr > 0`` and ignore boxes are present, set
           the overlaps of every prior whose maximum IoF with the ignore
           boxes exceeds ``self.ignore_iof_thr`` to -1.
        5. Delegate the actual matching to the inherited
           ``self.assign_wrt_overlaps(overlaps, gt_labels)`` and, when the
           CPU path was taken, move the result tensors back to the
           original device of ``priors``.

        The matching rule of step 5 is not visible in this class. The
        example below suggests that ``gt_inds`` holds 0 for an unmatched
        prior and the 1-based gt index for a matched one; see
        ``MaxIoUAssigner.assign_wrt_overlaps`` for the authoritative
        definition.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Only ``priors`` is read here; it is expected
                to have shape (n, 4).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` (expected shape (k, 4)) and
                ``labels`` (expected shape (k, )) are read.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored. Only its ``bboxes`` attribute is read.
                Defaults to None.
            **kwargs: Accepted for interface compatibility and not used.

        Returns:
            :obj:`AssignResult`: The assign result.

        Example:
            >>> from mmengine.structures import InstanceData
            >>> self = TransMaxIoUAssigner(0.5, 0.5)
            >>> pred_instances = InstanceData()
            >>> # Both priors are squares on the diagonal, so swapping
            >>> # their columns leaves them unchanged.
            >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
            ...                                       [10, 10, 20, 20]])
            >>> gt_instances = InstanceData()
            >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 9]])
            >>> gt_instances.labels = torch.Tensor([0])
            >>> assign_result = self.assign(pred_instances, gt_instances)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        gt_labels = gt_instances.labels
        gt_bboxes_ignore = None
        if gt_instances_ignore is not None:
            gt_bboxes_ignore = gt_instances_ignore.bboxes

        # Fall back to the CPU when the number of gt boxes exceeds the
        # configured threshold (a non-positive threshold disables this).
        assign_on_cpu = (
            self.gpu_assign_thr > 0
            and gt_bboxes.shape[0] > self.gpu_assign_thr)
        if assign_on_cpu:
            # Remember where the results have to be moved back to.
            device = priors.device
            priors = priors.cpu()
            gt_bboxes = gt_bboxes.cpu()
            gt_labels = gt_labels.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()

        trans_priors = self._swap_xy(priors)
        overlaps = self.iou_calculator(gt_bboxes, trans_priors)

        has_ignore_regions = (
            self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
            and gt_bboxes_ignore.numel() > 0 and trans_priors.numel() > 0)
        if has_ignore_regions:
            # Either argument order is reduced over the ignore-box axis,
            # so ``ignore_max_overlaps`` indexes the priors (the columns
            # of ``overlaps``).
            if self.ignore_wrt_candidates:
                ignore_overlaps = self.iou_calculator(
                    trans_priors, gt_bboxes_ignore, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            else:
                ignore_overlaps = self.iou_calculator(
                    gt_bboxes_ignore, trans_priors, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
            # Mark priors that fall into an ignore region.
            overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
        if assign_on_cpu:
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result
|
|