Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| import torch | |
| import torch.nn as nn | |
| from mmengine.structures import BaseDataElement | |
| from mmdet.models.utils import multi_apply | |
| from mmdet.registry import MODELS, TASK_UTILS | |
| from mmdet.utils import reduce_mean | |
| class DDQAuxLoss(nn.Module): | |
| """DDQ auxiliary branches loss for dense queries. | |
| Args: | |
| loss_cls (dict): | |
| Configuration of classification loss function. | |
| loss_bbox (dict): | |
| Configuration of bbox regression loss function. | |
| train_cfg (dict): | |
| Configuration of gt targets assigner for each predicted bbox. | |
| """ | |
| def __init__( | |
| self, | |
| loss_cls=dict( | |
| type='QualityFocalLoss', | |
| use_sigmoid=True, | |
| activated=True, # use probability instead of logit as input | |
| beta=2.0, | |
| loss_weight=1.0), | |
| loss_bbox=dict(type='GIoULoss', loss_weight=2.0), | |
| train_cfg=dict( | |
| assigner=dict(type='TopkHungarianAssigner', topk=8), | |
| alpha=1, | |
| beta=6), | |
| ): | |
| super(DDQAuxLoss, self).__init__() | |
| self.train_cfg = train_cfg | |
| self.loss_cls = MODELS.build(loss_cls) | |
| self.loss_bbox = MODELS.build(loss_bbox) | |
| self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) | |
| sampler_cfg = dict(type='PseudoSampler') | |
| self.sampler = TASK_UTILS.build(sampler_cfg) | |
| def loss_single(self, cls_score, bbox_pred, labels, label_weights, | |
| bbox_targets, alignment_metrics): | |
| """Calculate auxiliary branches loss for dense queries for one image. | |
| Args: | |
| cls_score (Tensor): Predicted normalized classification | |
| scores for one image, has shape (num_dense_queries, | |
| cls_out_channels). | |
| bbox_pred (Tensor): Predicted unnormalized bbox coordinates | |
| for one image, has shape (num_dense_queries, 4) with the | |
| last dimension arranged as (x1, y1, x2, y2). | |
| labels (Tensor): Labels for one image. | |
| label_weights (Tensor): Label weights for one image. | |
| bbox_targets (Tensor): Bbox targets for one image. | |
| alignment_metrics (Tensor): Normalized alignment metrics for one | |
| image. | |
| Returns: | |
| tuple: A tuple of loss components and loss weights. | |
| """ | |
| bbox_targets = bbox_targets.reshape(-1, 4) | |
| labels = labels.reshape(-1) | |
| alignment_metrics = alignment_metrics.reshape(-1) | |
| label_weights = label_weights.reshape(-1) | |
| targets = (labels, alignment_metrics) | |
| cls_loss_func = self.loss_cls | |
| loss_cls = cls_loss_func( | |
| cls_score, targets, label_weights, avg_factor=1.0) | |
| # FG cat_id: [0, num_classes -1], BG cat_id: num_classes | |
| bg_class_ind = cls_score.size(-1) | |
| pos_inds = ((labels >= 0) | |
| & (labels < bg_class_ind)).nonzero().squeeze(1) | |
| if len(pos_inds) > 0: | |
| pos_bbox_targets = bbox_targets[pos_inds] | |
| pos_bbox_pred = bbox_pred[pos_inds] | |
| pos_decode_bbox_pred = pos_bbox_pred | |
| pos_decode_bbox_targets = pos_bbox_targets | |
| # regression loss | |
| pos_bbox_weight = alignment_metrics[pos_inds] | |
| loss_bbox = self.loss_bbox( | |
| pos_decode_bbox_pred, | |
| pos_decode_bbox_targets, | |
| weight=pos_bbox_weight, | |
| avg_factor=1.0) | |
| else: | |
| loss_bbox = bbox_pred.sum() * 0 | |
| pos_bbox_weight = bbox_targets.new_tensor(0.) | |
| return loss_cls, loss_bbox, alignment_metrics.sum( | |
| ), pos_bbox_weight.sum() | |
| def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas, | |
| **kwargs): | |
| """Calculate auxiliary branches loss for dense queries. | |
| Args: | |
| cls_scores (Tensor): Predicted normalized classification | |
| scores, has shape (bs, num_dense_queries, | |
| cls_out_channels). | |
| bbox_preds (Tensor): Predicted unnormalized bbox coordinates, | |
| has shape (bs, num_dense_queries, 4) with the last | |
| dimension arranged as (x1, y1, x2, y2). | |
| gt_bboxes (list[Tensor]): List of unnormalized ground truth | |
| bboxes for each image, each has shape (num_gt, 4) with the | |
| last dimension arranged as (x1, y1, x2, y2). | |
| NOTE: num_gt is dynamic for each image. | |
| gt_labels (list[Tensor]): List of ground truth classification | |
| index for each image, each has shape (num_gt,). | |
| NOTE: num_gt is dynamic for each image. | |
| img_metas (list[dict]): Meta information for one image, | |
| e.g., image size, scaling factor, etc. | |
| Returns: | |
| dict: A dictionary of loss components. | |
| """ | |
| flatten_cls_scores = cls_scores | |
| flatten_bbox_preds = bbox_preds | |
| cls_reg_targets = self.get_targets( | |
| flatten_cls_scores, | |
| flatten_bbox_preds, | |
| gt_bboxes, | |
| img_metas, | |
| gt_labels_list=gt_labels, | |
| ) | |
| (labels_list, label_weights_list, bbox_targets_list, | |
| alignment_metrics_list) = cls_reg_targets | |
| losses_cls, losses_bbox, \ | |
| cls_avg_factors, bbox_avg_factors = multi_apply( | |
| self.loss_single, | |
| flatten_cls_scores, | |
| flatten_bbox_preds, | |
| labels_list, | |
| label_weights_list, | |
| bbox_targets_list, | |
| alignment_metrics_list, | |
| ) | |
| cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() | |
| losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) | |
| bbox_avg_factor = reduce_mean( | |
| sum(bbox_avg_factors)).clamp_(min=1).item() | |
| losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) | |
| return dict(aux_loss_cls=losses_cls, aux_loss_bbox=losses_bbox) | |
| def get_targets(self, | |
| cls_scores, | |
| bbox_preds, | |
| gt_bboxes_list, | |
| img_metas, | |
| gt_labels_list=None, | |
| **kwargs): | |
| """Compute regression and classification targets for a batch images. | |
| Args: | |
| cls_scores (Tensor): Predicted normalized classification | |
| scores, has shape (bs, num_dense_queries, | |
| cls_out_channels). | |
| bbox_preds (Tensor): Predicted unnormalized bbox coordinates, | |
| has shape (bs, num_dense_queries, 4) with the last | |
| dimension arranged as (x1, y1, x2, y2). | |
| gt_bboxes_list (List[Tensor]): List of unnormalized ground truth | |
| bboxes for each image, each has shape (num_gt, 4) with the | |
| last dimension arranged as (x1, y1, x2, y2). | |
| NOTE: num_gt is dynamic for each image. | |
| img_metas (list[dict]): Meta information for one image, | |
| e.g., image size, scaling factor, etc. | |
| gt_labels_list (list[Tensor]): List of ground truth classification | |
| index for each image, each has shape (num_gt,). | |
| NOTE: num_gt is dynamic for each image. | |
| Default: None. | |
| Returns: | |
| tuple: a tuple containing the following targets. | |
| - all_labels (list[Tensor]): Labels for all images. | |
| - all_label_weights (list[Tensor]): Label weights for all images. | |
| - all_bbox_targets (list[Tensor]): Bbox targets for all images. | |
| - all_assign_metrics (list[Tensor]): Normalized alignment metrics | |
| for all images. | |
| """ | |
| (all_labels, all_label_weights, all_bbox_targets, | |
| all_assign_metrics) = multi_apply(self._get_target_single, cls_scores, | |
| bbox_preds, gt_bboxes_list, | |
| gt_labels_list, img_metas) | |
| return (all_labels, all_label_weights, all_bbox_targets, | |
| all_assign_metrics) | |
| def _get_target_single(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, | |
| img_meta, **kwargs): | |
| """Compute regression and classification targets for one image. | |
| Args: | |
| cls_scores (Tensor): Predicted normalized classification | |
| scores for one image, has shape (num_dense_queries, | |
| cls_out_channels). | |
| bbox_preds (Tensor): Predicted unnormalized bbox coordinates | |
| for one image, has shape (num_dense_queries, 4) with the | |
| last dimension arranged as (x1, y1, x2, y2). | |
| gt_bboxes (Tensor): Unnormalized ground truth | |
| bboxes for one image, has shape (num_gt, 4) with the | |
| last dimension arranged as (x1, y1, x2, y2). | |
| NOTE: num_gt is dynamic for each image. | |
| gt_labels (Tensor): Ground truth classification | |
| index for the image, has shape (num_gt,). | |
| NOTE: num_gt is dynamic for each image. | |
| img_meta (dict): Meta information for one image. | |
| Returns: | |
| tuple[Tensor]: a tuple containing the following for one image. | |
| - labels (Tensor): Labels for one image. | |
| - label_weights (Tensor): Label weights for one image. | |
| - bbox_targets (Tensor): Bbox targets for one image. | |
| - norm_alignment_metrics (Tensor): Normalized alignment | |
| metrics for one image. | |
| """ | |
| if len(gt_labels) == 0: | |
| num_valid_anchors = len(cls_scores) | |
| bbox_targets = torch.zeros_like(bbox_preds) | |
| labels = bbox_preds.new_full((num_valid_anchors, ), | |
| cls_scores.size(-1), | |
| dtype=torch.long) | |
| label_weights = bbox_preds.new_zeros( | |
| num_valid_anchors, dtype=torch.float) | |
| norm_alignment_metrics = bbox_preds.new_zeros( | |
| num_valid_anchors, dtype=torch.float) | |
| return (labels, label_weights, bbox_targets, | |
| norm_alignment_metrics) | |
| assign_result = self.assigner.assign(cls_scores, bbox_preds, gt_bboxes, | |
| gt_labels, img_meta) | |
| assign_ious = assign_result.max_overlaps | |
| assign_metrics = assign_result.assign_metrics | |
| pred_instances = BaseDataElement() | |
| gt_instances = BaseDataElement() | |
| pred_instances.bboxes = bbox_preds | |
| gt_instances.bboxes = gt_bboxes | |
| pred_instances.priors = cls_scores | |
| gt_instances.labels = gt_labels | |
| sampling_result = self.sampler.sample(assign_result, pred_instances, | |
| gt_instances) | |
| num_valid_anchors = len(cls_scores) | |
| bbox_targets = torch.zeros_like(bbox_preds) | |
| labels = bbox_preds.new_full((num_valid_anchors, ), | |
| cls_scores.size(-1), | |
| dtype=torch.long) | |
| label_weights = bbox_preds.new_zeros( | |
| num_valid_anchors, dtype=torch.float) | |
| norm_alignment_metrics = bbox_preds.new_zeros( | |
| num_valid_anchors, dtype=torch.float) | |
| pos_inds = sampling_result.pos_inds | |
| neg_inds = sampling_result.neg_inds | |
| if len(pos_inds) > 0: | |
| # point-based | |
| pos_bbox_targets = sampling_result.pos_gt_bboxes | |
| bbox_targets[pos_inds, :] = pos_bbox_targets | |
| if gt_labels is None: | |
| # Only dense_heads gives gt_labels as None | |
| # Foreground is the first class since v2.5.0 | |
| labels[pos_inds] = 0 | |
| else: | |
| labels[pos_inds] = gt_labels[ | |
| sampling_result.pos_assigned_gt_inds] | |
| label_weights[pos_inds] = 1.0 | |
| if len(neg_inds) > 0: | |
| label_weights[neg_inds] = 1.0 | |
| class_assigned_gt_inds = torch.unique( | |
| sampling_result.pos_assigned_gt_inds) | |
| for gt_inds in class_assigned_gt_inds: | |
| gt_class_inds = sampling_result.pos_assigned_gt_inds == gt_inds | |
| pos_alignment_metrics = assign_metrics[gt_class_inds] | |
| pos_ious = assign_ious[gt_class_inds] | |
| pos_norm_alignment_metrics = pos_alignment_metrics / ( | |
| pos_alignment_metrics.max() + 10e-8) * pos_ious.max() | |
| norm_alignment_metrics[ | |
| pos_inds[gt_class_inds]] = pos_norm_alignment_metrics | |
| return (labels, label_weights, bbox_targets, norm_alignment_metrics) | |