# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.models import BaseDetector
from mmdet.models.utils import unpack_gt_instances
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType, OptConfigType

@torch.jit.script
def rescoring_mask(scores: Tensor, mask_pred: Tensor,
                   masks: Tensor) -> Tensor:
    """Rescore classification scores by "maskness": the mean soft-mask
    probability inside each binarized mask region."""
    mask_pred_ = mask_pred.float()
    return scores * ((masks * mask_pred_).sum([1, 2]) /
                     (mask_pred_.sum([1, 2]) + 1e-6))
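# Illustrative example (values are made up): a classification score of 0.9
# whose binarized mask covers two pixels with soft probabilities 0.8 and
# 0.6 is rescored to roughly 0.9 * (0.8 + 0.6) / 2 = 0.63:
#   scores = torch.tensor([0.9])
#   soft = torch.tensor([[[0.8, 0.6], [0.1, 0.0]]])
#   rescoring_mask(scores, soft > 0.45, soft)  # -> ~tensor([0.63])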

@MODELS.register_module()
class SparseInst(BaseDetector):
    """Implementation of `SparseInst <https://arxiv.org/abs/2203.12827>`_

    Args:
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        backbone (:obj:`ConfigDict` or dict): The backbone module.
        encoder (:obj:`ConfigDict` or dict): The encoder module.
        decoder (:obj:`ConfigDict` or dict): The decoder module.
        criterion (:obj:`ConfigDict` or dict, optional): The training matcher
            and losses. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of SparseInst. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): The config to control
            the initialization. Defaults to None.
    """

def __init__(self,
data_preprocessor: ConfigType,
backbone: ConfigType,
encoder: ConfigType,
decoder: ConfigType,
criterion: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptConfigType = None):
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
# backbone
self.backbone = MODELS.build(backbone)
# encoder & decoder
self.encoder = MODELS.build(encoder)
self.decoder = MODELS.build(decoder)
# matcher & loss (matcher is built in loss)
self.criterion = MODELS.build(criterion)
# inference
self.cls_threshold = test_cfg.score_thr
self.mask_threshold = test_cfg.mask_thr_binary
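
    # A minimal config sketch for building this detector through the
    # registry. The module names below follow the SparseInst project layout
    # but are assumptions here, not a verified shipped config:
    #   model = dict(
    #       type='SparseInst',
    #       data_preprocessor=dict(type='DetDataPreprocessor'),
    #       backbone=dict(type='ResNet', depth=50, out_indices=(1, 2, 3)),
    #       encoder=dict(type='InstanceContextEncoder'),
    #       decoder=dict(type='BaseIAMDecoder'),
    #       criterion=dict(type='SparseInstCriterion'),
    #       test_cfg=dict(score_thr=0.005, mask_thr_binary=0.45))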

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> dict:
        """Network forward process: backbone, encoder and decoder forward
        without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            dict: Raw decoder outputs, including ``pred_logits``,
            ``pred_masks`` and ``pred_scores``.
        """
x = self.backbone(batch_inputs)
x = self.encoder(x)
results = self.decoder(x)
return results

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            ``pred_instances`` with the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of masks, has a shape
              (num_instances, ).
            - masks (Tensor): Binary instance masks, has a shape
              (num_instances, H, W).
            - bboxes (Tensor): All-zero placeholder boxes of shape
              (num_instances, 4), kept only so that box-based metrics
              do not fail.
        """
        max_shape = batch_inputs.shape[-2:]
        output = self._forward(batch_inputs)
        pred_scores = output['pred_logits'].sigmoid()
        pred_masks = output['pred_masks'].sigmoid()
        pred_objectness = output['pred_scores'].sigmoid()
        # fuse classification and objectness scores with a geometric mean
        pred_scores = torch.sqrt(pred_scores * pred_objectness)
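        # e.g. a class probability of 0.81 and an objectness of 0.49 fuse to
        # sqrt(0.81 * 0.49) = 0.63, so both terms must be high for a
        # detection to keep a high score (values here are illustrative).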
results_list = []
        for scores_per_image, mask_pred_per_image, datasample in zip(
                pred_scores, pred_masks, batch_data_samples):
result = InstanceData()
# max/argmax
scores, labels = scores_per_image.max(dim=-1)
# cls threshold
keep = scores > self.cls_threshold
scores = scores[keep]
labels = labels[keep]
mask_pred_per_image = mask_pred_per_image[keep]
            if scores.size(0) == 0:
                # no instance above the score threshold; keep empty results
                result.scores = scores
                result.labels = labels
                results_list.append(result)
                continue
img_meta = datasample.metainfo
# rescoring mask using maskness
scores = rescoring_mask(scores,
mask_pred_per_image > self.mask_threshold,
mask_pred_per_image)
            h, w = img_meta['img_shape'][:2]
            # upsample to the padded input size, then crop away the padding
            mask_pred_per_image = F.interpolate(
                mask_pred_per_image.unsqueeze(1),
                size=max_shape,
                mode='bilinear',
                align_corners=False)[:, :, :h, :w]
            if rescale:
                ori_h, ori_w = img_meta['ori_shape'][:2]
                mask_pred_per_image = F.interpolate(
                    mask_pred_per_image,
                    size=(ori_h, ori_w),
                    mode='bilinear',
                    align_corners=False)
            # drop the channel dim added for interpolation in both branches
            mask_pred_per_image = mask_pred_per_image.squeeze(1)
mask_pred = mask_pred_per_image > self.mask_threshold
result.masks = mask_pred
result.scores = scores
result.labels = labels
            # create all-zero placeholder bboxes in InstanceData so that
            # box-based metrics do not fail on mask-only predictions.
            result.bboxes = result.scores.new_zeros(len(scores), 4)
results_list.append(result)
batch_data_samples = self.add_pred_to_datasample(
batch_data_samples, results_list)
return batch_data_samples

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
outs = self._forward(batch_inputs)
(batch_gt_instances, batch_gt_instances_ignore,
batch_img_metas) = unpack_gt_instances(batch_data_samples)
losses = self.criterion(outs, batch_gt_instances, batch_img_metas,
batch_gt_instances_ignore)
return losses

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H, W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
x = self.backbone(batch_inputs)
x = self.encoder(x)
return x
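
# Minimal inference sketch, assuming the SparseInst project modules are
# importable/registered; the config and checkpoint paths are hypothetical:
#   from mmdet.apis import init_detector, inference_detector
#   model = init_detector('configs/sparseinst/sparseinst_r50.py',
#                         'sparseinst_r50.pth', device='cuda:0')
#   result = inference_detector(model, 'demo.jpg')
#   # result.pred_instances holds .masks, .scores, .labels and the
#   # all-zero placeholder .bboxes produced by SparseInst.predict above.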