# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple, Union

from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from .base import BaseDetector


class DetectionTransformer(BaseDetector, metaclass=ABCMeta):
    r"""Base class for Detection Transformer.

    In Detection Transformer, an encoder is used to process the output
    features of the neck, then several queries interact with the encoder
    features via a decoder, and the bounding box head performs regression
    and classification on the decoder output.

    Args:
        backbone (:obj:`ConfigDict` or dict): Config of the backbone.
        neck (:obj:`ConfigDict` or dict, optional): Config of the neck.
            Defaults to None.
        encoder (:obj:`ConfigDict` or dict, optional): Config of the
            Transformer encoder. Defaults to None.
        decoder (:obj:`ConfigDict` or dict, optional): Config of the
            Transformer decoder. Defaults to None.
        bbox_head (:obj:`ConfigDict` or dict, optional): Config for the
            bounding box head module. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict, optional): Config
            of the positional encoding module. Defaults to None.
        num_queries (int, optional): Number of decoder queries in the
            Transformer. Defaults to 100.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the bounding box head module. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the bounding box head module. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): The config to control
            the initialization. Defaults to None.
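
    Example:
        A minimal, illustrative sketch of building a concrete subclass from
        a config. The ``type`` values name modules registered in mmdet
        (e.g. ``'DETR'``, ``'ResNet'``); the full set of required fields
        depends on the concrete subclass and is omitted here for brevity:

        .. code:: python

            from mmdet.registry import MODELS

            cfg = dict(
                type='DETR',  # a concrete DetectionTransformer subclass
                backbone=dict(type='ResNet', depth=50),
                # encoder/decoder/positional_encoding/bbox_head configs
                # omitted for brevity
                num_queries=100)
            detector = MODELS.build(cfg)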
| """ | |

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 encoder: OptConfigType = None,
                 decoder: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 positional_encoding: OptConfigType = None,
                 num_queries: int = 100,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # process args
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.encoder = encoder
        self.decoder = decoder
        self.positional_encoding = positional_encoding
        self.num_queries = num_queries

        # init model layers
        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)
        self.bbox_head = MODELS.build(bbox_head)
        self._init_layers()

    @abstractmethod
    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        pass

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (bs, dim, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
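
        Example:
            A minimal usage sketch (illustrative; assumes ``detector`` is a
            built subclass instance and the batch is already preprocessed):

            .. code:: python

                losses = detector.loss(batch_inputs, batch_data_samples)
                # reduce to a scalar for the backward pass
                total_loss = sum(v.sum() for v in losses.values())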
| """ | |
| img_feats = self.extract_feat(batch_inputs) | |
| head_inputs_dict = self.forward_transformer(img_feats, | |
| batch_data_samples) | |
| losses = self.bbox_head.loss( | |
| **head_inputs_dict, batch_data_samples=batch_data_samples) | |
| return losses | |

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the input
            images. Each DetDataSample usually contains ``pred_instances``,
            which in turn usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
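
        Example:
            An illustrative sketch of reading predictions (assumes
            ``detector`` is a built subclass instance in eval mode; the
            score threshold is arbitrary):

            .. code:: python

                results = detector.predict(batch_inputs, batch_data_samples)
                inst = results[0].pred_instances
                keep = inst.scores > 0.5  # hypothetical threshold
                boxes, labels = inst.bboxes[keep], inst.labels[keep]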
| """ | |
| img_feats = self.extract_feat(batch_inputs) | |
| head_inputs_dict = self.forward_transformer(img_feats, | |
| batch_data_samples) | |
| results_list = self.bbox_head.predict( | |
| **head_inputs_dict, | |
| rescale=rescale, | |
| batch_data_samples=batch_data_samples) | |
| batch_data_samples = self.add_pred_to_datasample( | |
| batch_data_samples, results_list) | |
| return batch_data_samples | |

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[Tensor]: A tuple of features from ``bbox_head`` forward.
        """
        img_feats = self.extract_feat(batch_inputs)
        head_inputs_dict = self.forward_transformer(img_feats,
                                                    batch_data_samples)
        results = self.bbox_head.forward(**head_inputs_dict)
        return results

    def forward_transformer(self,
                            img_feats: Tuple[Tensor],
                            batch_data_samples: OptSampleList = None) -> Dict:
        """Forward process of the Transformer, which includes four steps:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'. We
        summarize the parameter flow of existing DETR-like detectors,
        which can be illustrated as follows:

        .. code:: text

                 img_feats & batch_data_samples
                               |
                               V
                      +-----------------+
                      | pre_transformer |
                      +-----------------+
                          |          |
                          |          V
                          |    +-----------------+
                          |    | forward_encoder |
                          |    +-----------------+
                          |             |
                          |             V
                          |     +---------------+
                          |     |  pre_decoder  |
                          |     +---------------+
                          |         |       |
                          V         V       |
                      +-----------------+   |
                      | forward_decoder |   |
                      +-----------------+   |
                                |           |
                                V           V
                               head_inputs_dict

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may contain
            `references` including the initial and intermediate references.
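
        Example:
            An illustrative sketch of the returned dict for a plain
            DETR-style subclass (the shape is indicative and not guaranteed
            by this base class):

            .. code:: python

                head_inputs_dict = detector.forward_transformer(
                    img_feats, batch_data_samples)
                # e.g. (num_decoder_layers, bs, num_queries, dim) when the
                # decoder returns intermediate layer outputs
                hidden_states = head_inputs_dict['hidden_states']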
| """ | |
| encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( | |
| img_feats, batch_data_samples) | |
| encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict) | |
| tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict) | |
| decoder_inputs_dict.update(tmp_dec_in) | |
| decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) | |
| head_inputs_dict.update(decoder_outputs_dict) | |
| return head_inputs_dict | |

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).

        Returns:
            tuple[Tensor]: Tuple of feature maps from the backbone or neck.
            Each feature map has shape (bs, dim, H, W).
        """
        x = self.backbone(batch_inputs)
        if self.with_neck:
            x = self.neck(x)
        return x

    @abstractmethod
    def pre_transformer(
            self,
            img_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]:
        """Process image features before feeding them to the transformer.

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of encoder
            and the second dict contains the inputs of decoder.

            - encoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_encoder()`, which includes 'feat', 'feat_mask',
              'feat_pos', and other algorithm-specific arguments.
            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'memory_mask', and
              other algorithm-specific arguments.
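
        Example:
            An illustrative single-scale sketch of what an implementation
            might return. It flattens the last feature map as plain-DETR-style
            detectors do; ``feat_mask`` and ``feat_pos`` would be built from
            the image metas and ``self.positional_encoding`` and are assumed
            to exist here:

            .. code:: python

                feat = img_feats[-1]  # (bs, dim, h, w)
                bs, dim, h, w = feat.shape
                # flatten to a token sequence: (bs, num_feat_points, dim)
                feat = feat.view(bs, dim, -1).permute(0, 2, 1)
                encoder_inputs_dict = dict(
                    feat=feat, feat_mask=feat_mask, feat_pos=feat_pos)
                decoder_inputs_dict = dict(memory_mask=feat_mask,
                                           memory_pos=feat_pos)
                return encoder_inputs_dict, decoder_inputs_dict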
| """ | |
| pass | |

    @abstractmethod
    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, **kwargs) -> Dict:
        """Forward with Transformer encoder.

        Args:
            feat (Tensor): Sequential features, has shape (bs,
                num_feat_points, dim).
            feat_mask (Tensor): ByteTensor, the padding mask of the features,
                has shape (bs, num_feat_points).
            feat_pos (Tensor): The positional embeddings of the features, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of encoder outputs, which includes the
            `memory` of the encoder output and other algorithm-specific
            arguments.
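
        Example:
            An illustrative implementation sketch (assumes ``self.encoder``
            is a built encoder module accepting this keyword interface, as
            DETR-style subclasses typically do):

            .. code:: python

                memory = self.encoder(
                    query=feat,
                    query_pos=feat_pos,
                    key_padding_mask=feat_mask)  # (bs, num_feat_points, dim)
                return dict(memory=memory)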
| """ | |
| pass | |

    @abstractmethod
    def pre_decoder(self, memory: Tensor, **kwargs) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer
        decoder, such as `query`, `query_pos`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer
                encoder, has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of decoder
            and the second dict contains the inputs of the bbox_head function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos',
              'memory', and other algorithm-specific arguments.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two stage' or 'query selection' strategies.
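
        Example:
            An illustrative sketch of a learned-query preparation (assumes
            ``import torch`` and a ``self.query_embedding`` created in
            ``_init_layers``; content queries starting from zeros is the
            plain-DETR convention):

            .. code:: python

                bs = memory.size(0)
                query_pos = self.query_embedding.weight  # (num_queries, dim)
                query_pos = query_pos.unsqueeze(0).repeat(bs, 1, 1)
                query = torch.zeros_like(query_pos)
                decoder_inputs_dict = dict(
                    query=query, query_pos=query_pos, memory=memory)
                head_inputs_dict = dict()  # empty for one-stage variants
                return decoder_inputs_dict, head_inputs_dict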
| """ | |
| pass | |

    @abstractmethod
    def forward_decoder(self, query: Tensor, query_pos: Tensor,
                        memory: Tensor, **kwargs) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer
                encoder, has shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output, `references` including
            the initial and intermediate reference_points, and other
            algorithm-specific arguments.
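
        Example:
            An illustrative implementation sketch (assumes ``self.decoder``
            is a built decoder module with this keyword interface, and that
            ``memory_mask`` and ``memory_pos`` arrive via ``**kwargs``, as
            in DETR-style subclasses):

            .. code:: python

                hidden_states = self.decoder(
                    query=query,
                    key=memory,
                    value=memory,
                    query_pos=query_pos,
                    key_pos=memory_pos,
                    key_padding_mask=memory_mask)
                # e.g. (num_decoder_layers, bs, num_queries, dim)
                return dict(hidden_states=hidden_states)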
| """ | |
| pass | |