| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| from typing import List, Optional, Tuple | |
| import torch | |
| from mmengine.structures import InstanceData | |
| from torch import Tensor | |
| from mmdet.models.utils import (filter_gt_instances, rename_loss_dict, | |
| reweight_loss_dict) | |
| from mmdet.registry import MODELS | |
| from mmdet.structures import SampleList | |
| from mmdet.structures.bbox import bbox2roi, bbox_project | |
| from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig | |
| from ..utils.misc import unpack_gt_instances | |
| from .semi_base import SemiBaseDetector | |
class SoftTeacher(SemiBaseDetector):
    r"""Implementation of `End-to-End Semi-Supervised Object Detection
    with Soft Teacher <https://arxiv.org/abs/2106.09018>`_.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        # All construction (teacher/student copies, EMA hooks, etc.) is
        # handled by the SemiBaseDetector base class; simply forward.
        super().__init__(
            detector=detector,
            semi_train_cfg=semi_train_cfg,
            semi_test_cfg=semi_test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
| def loss_by_pseudo_instances(self, | |
| batch_inputs: Tensor, | |
| batch_data_samples: SampleList, | |
| batch_info: Optional[dict] = None) -> dict: | |
| """Calculate losses from a batch of inputs and pseudo data samples. | |
| Args: | |
| batch_inputs (Tensor): Input images of shape (N, C, H, W). | |
| These should usually be mean centered and std scaled. | |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch | |
| data samples. It usually includes information such | |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, | |
| which are `pseudo_instance` or `pseudo_panoptic_seg` | |
| or `pseudo_sem_seg` in fact. | |
| batch_info (dict): Batch information of teacher model | |
| forward propagation process. Defaults to None. | |
| Returns: | |
| dict: A dictionary of loss components | |
| """ | |
| x = self.student.extract_feat(batch_inputs) | |
| losses = {} | |
| rpn_losses, rpn_results_list = self.rpn_loss_by_pseudo_instances( | |
| x, batch_data_samples) | |
| losses.update(**rpn_losses) | |
| losses.update(**self.rcnn_cls_loss_by_pseudo_instances( | |
| x, rpn_results_list, batch_data_samples, batch_info)) | |
| losses.update(**self.rcnn_reg_loss_by_pseudo_instances( | |
| x, rpn_results_list, batch_data_samples)) | |
| unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.) | |
| return rename_loss_dict('unsup_', | |
| reweight_loss_dict(losses, unsup_weight)) | |
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model.

        Runs the (frozen) teacher on the weakly-augmented batch, filters
        its detections into pseudo labels, attaches a per-box regression
        uncertainty, and projects the boxes back into the original image
        frame so they can later be re-projected onto the student's view.

        Returns:
            tuple: ``(batch_data_samples, batch_info)`` where the data
            samples carry pseudo ``gt_instances`` and ``batch_info`` caches
            teacher features/metadata for the rcnn classification loss.
        """
        assert self.teacher.with_bbox, 'Bbox head must be implemented.'
        x = self.teacher.extract_feat(batch_inputs)
        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.teacher.rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]
        results_list = self.teacher.roi_head.predict(
            x, rpn_results_list, batch_data_samples, rescale=False)
        # Teacher detections become the (unfiltered) pseudo ground truth.
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results
        # Drop low-confidence detections before computing uncertainty.
        batch_data_samples = filter_gt_instances(
            batch_data_samples,
            score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr)
        reg_uncs_list = self.compute_uncertainty_with_aug(
            x, batch_data_samples)
        for data_samples, reg_uncs in zip(batch_data_samples, reg_uncs_list):
            data_samples.gt_instances['reg_uncs'] = reg_uncs
            # Project boxes from the teacher's augmented view back to the
            # original image via the inverse homography, so the student
            # pipeline can map them into its own augmented view later.
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)
        # Cache teacher-side features and geometry; consumed by
        # rcnn_cls_loss_by_pseudo_instances for soft background scoring.
        batch_info = {
            'feat': x,
            'img_shape': [],
            'homography_matrix': [],
            'metainfo': []
        }
        for data_samples in batch_data_samples:
            batch_info['img_shape'].append(data_samples.img_shape)
            batch_info['homography_matrix'].append(
                torch.from_numpy(data_samples.homography_matrix).to(
                    self.data_preprocessor.device))
            batch_info['metainfo'].append(data_samples.metainfo)
        return batch_data_samples, batch_info
| def rpn_loss_by_pseudo_instances(self, x: Tuple[Tensor], | |
| batch_data_samples: SampleList) -> dict: | |
| """Calculate rpn loss from a batch of inputs and pseudo data samples. | |
| Args: | |
| x (tuple[Tensor]): Features from FPN. | |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch | |
| data samples. It usually includes information such | |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, | |
| which are `pseudo_instance` or `pseudo_panoptic_seg` | |
| or `pseudo_sem_seg` in fact. | |
| Returns: | |
| dict: A dictionary of rpn loss components | |
| """ | |
| rpn_data_samples = copy.deepcopy(batch_data_samples) | |
| rpn_data_samples = filter_gt_instances( | |
| rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr) | |
| proposal_cfg = self.student.train_cfg.get('rpn_proposal', | |
| self.student.test_cfg.rpn) | |
| # set cat_id of gt_labels to 0 in RPN | |
| for data_sample in rpn_data_samples: | |
| data_sample.gt_instances.labels = \ | |
| torch.zeros_like(data_sample.gt_instances.labels) | |
| rpn_losses, rpn_results_list = self.student.rpn_head.loss_and_predict( | |
| x, rpn_data_samples, proposal_cfg=proposal_cfg) | |
| for key in rpn_losses.keys(): | |
| if 'loss' in key and 'rpn' not in key: | |
| rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) | |
| return rpn_losses, rpn_results_list | |
    def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
                                          unsup_rpn_results_list: InstanceList,
                                          batch_data_samples: SampleList,
                                          batch_info: dict) -> dict:
        """Calculate classification loss from a batch of inputs and pseudo data
        samples.

        Implements the "soft teacher" classification loss: background
        proposals are re-weighted by the teacher's background score
        instead of being trained with uniform weight.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            unsup_rpn_results_list (list[:obj:`InstanceData`]):
                List of region proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process.

        Returns:
            dict[str, Tensor]: A dictionary of rcnn
            classification loss components
        """
        # Deep-copy so that popping 'bboxes'/filtering below does not
        # corrupt the caller's proposals and pseudo labels (the reg loss
        # reuses the same inputs).
        rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
        cls_data_samples = copy.deepcopy(batch_data_samples)
        cls_data_samples = filter_gt_instances(
            cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
        outputs = unpack_gt_instances(cls_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs
        # assign gts and sample proposals
        num_imgs = len(cls_data_samples)
        sampling_results = []
        for i in range(num_imgs):
            # rename rpn_results.bboxes to rpn_results.priors
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')
            assign_result = self.student.roi_head.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.student.roi_head.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
        selected_bboxes = [res.priors for res in sampling_results]
        rois = bbox2roi(selected_bboxes)
        bbox_results = self.student.roi_head._bbox_forward(x, rois)
        # cls_reg_targets is a tuple of labels, label_weights,
        # and bbox_targets, bbox_weights
        cls_reg_targets = self.student.roi_head.bbox_head.get_targets(
            sampling_results, self.student.train_cfg.rcnn)
        # Map the sampled student proposals into the teacher's image frame
        # (teacher_matrix @ student_matrix^-1) so the teacher can score them.
        selected_results_list = []
        for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip(
                selected_bboxes, batch_data_samples,
                batch_info['homography_matrix'], batch_info['img_shape']):
            student_matrix = torch.tensor(
                data_samples.homography_matrix, device=teacher_matrix.device)
            homography_matrix = teacher_matrix @ student_matrix.inverse()
            projected_bboxes = bbox_project(bboxes, homography_matrix,
                                            teacher_img_shape)
            selected_results_list.append(InstanceData(bboxes=projected_bboxes))
        with torch.no_grad():
            # Teacher scores the projected proposals; no gradients flow
            # into the teacher.
            results_list = self.teacher.roi_head.predict_bbox(
                batch_info['feat'],
                batch_info['metainfo'],
                selected_results_list,
                rcnn_test_cfg=None,
                rescale=False)
            # Last channel of the score vector is the background class.
            bg_score = torch.cat(
                [results.scores[:, -1] for results in results_list])
            # cls_reg_targets[0] is labels
            neg_inds = cls_reg_targets[
                0] == self.student.roi_head.bbox_head.num_classes
            # cls_reg_targets[1] is label_weights; in-place: soften the
            # weight of background proposals by the teacher's bg score.
            cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach()
        losses = self.student.roi_head.bbox_head.loss(
            bbox_results['cls_score'], bbox_results['bbox_pred'], rois,
            *cls_reg_targets)
        # cls_reg_targets[1] is label_weights; renormalize the cls loss so
        # the softened weights keep the same overall scale.
        losses['loss_cls'] = losses['loss_cls'] * len(
            cls_reg_targets[1]) / max(sum(cls_reg_targets[1]), 1.0)
        return losses
| def rcnn_reg_loss_by_pseudo_instances( | |
| self, x: Tuple[Tensor], unsup_rpn_results_list: InstanceList, | |
| batch_data_samples: SampleList) -> dict: | |
| """Calculate rcnn regression loss from a batch of inputs and pseudo | |
| data samples. | |
| Args: | |
| x (tuple[Tensor]): List of multi-level img features. | |
| unsup_rpn_results_list (list[:obj:`InstanceData`]): | |
| List of region proposals. | |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch | |
| data samples. It usually includes information such | |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, | |
| which are `pseudo_instance` or `pseudo_panoptic_seg` | |
| or `pseudo_sem_seg` in fact. | |
| Returns: | |
| dict[str, Tensor]: A dictionary of rcnn | |
| regression loss components | |
| """ | |
| rpn_results_list = copy.deepcopy(unsup_rpn_results_list) | |
| reg_data_samples = copy.deepcopy(batch_data_samples) | |
| for data_samples in reg_data_samples: | |
| if data_samples.gt_instances.bboxes.shape[0] > 0: | |
| data_samples.gt_instances = data_samples.gt_instances[ | |
| data_samples.gt_instances.reg_uncs < | |
| self.semi_train_cfg.reg_pseudo_thr] | |
| roi_losses = self.student.roi_head.loss(x, rpn_results_list, | |
| reg_data_samples) | |
| return {'loss_bbox': roi_losses['loss_bbox']} | |
    def compute_uncertainty_with_aug(
            self, x: Tuple[Tensor],
            batch_data_samples: SampleList) -> List[Tensor]:
        """Compute uncertainty with augmented bboxes.

        Each pseudo box is jittered ``jitter_times`` times, the teacher
        refines every jittered copy, and the standard deviation of the
        refined boxes (normalized by box size) is used as the regression
        uncertainty.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.

        Returns:
            list[Tensor]: A list of uncertainty for pseudo bboxes.
        """
        auged_results_list = self.aug_box(batch_data_samples,
                                          self.semi_train_cfg.jitter_times,
                                          self.semi_train_cfg.jitter_scale)
        # flatten (times, n, 4) -> (times * n, 4) so the roi_head treats
        # every jittered copy as an independent proposal
        auged_results_list = [
            InstanceData(bboxes=auged.reshape(-1, auged.shape[-1]))
            for auged in auged_results_list
        ]
        # Temporarily disable the rcnn test cfg so predictions are returned
        # for every input proposal (presumably skipping NMS/score filtering,
        # keeping a 1:1 mapping with the jittered boxes); restored below.
        self.teacher.roi_head.test_cfg = None
        results_list = self.teacher.roi_head.predict(
            x, auged_results_list, batch_data_samples, rescale=False)
        self.teacher.roi_head.test_cfg = self.teacher.test_cfg.rcnn
        # reg_channel > 1 means class-specific box regression (4 per class).
        reg_channel = max(
            [results.bboxes.shape[-1] for results in results_list]) // 4
        bboxes = [
            results.bboxes.reshape(self.semi_train_cfg.jitter_times, -1,
                                   results.bboxes.shape[-1])
            if results.bboxes.numel() > 0 else results.bboxes.new_zeros(
                self.semi_train_cfg.jitter_times, 0, 4 * reg_channel).float()
            for results in results_list
        ]
        # Std over the jitter dimension is the raw uncertainty; the mean is
        # used only to measure box size for normalization below.
        box_unc = [bbox.std(dim=0) for bbox in bboxes]
        bboxes = [bbox.mean(dim=0) for bbox in bboxes]
        labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]
        if reg_channel != 1:
            # Select the 4 regression values of each box's own class.
            bboxes = [
                bbox.reshape(bbox.shape[0], reg_channel,
                             4)[torch.arange(bbox.shape[0]), label]
                for bbox, label in zip(bboxes, labels)
            ]
            box_unc = [
                unc.reshape(unc.shape[0], reg_channel,
                            4)[torch.arange(unc.shape[0]), label]
                for unc, label in zip(box_unc, labels)
            ]
        # (w, h) per box, clamped to avoid division by ~0 for tiny boxes.
        box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0)
                     for bbox in bboxes]
        # Normalize each coordinate's std by the box extent along its axis
        # ((w, h) tiled to (w, h, w, h)), then average over the 4 coords.
        box_unc = [
            torch.mean(
                unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4), dim=-1)
            if wh.numel() > 0 else unc for unc, wh in zip(box_unc, box_shape)
        ]
        return box_unc
| def aug_box(batch_data_samples, times, frac): | |
| """Augment bboxes with jitter.""" | |
| def _aug_single(box): | |
| box_scale = box[:, 2:4] - box[:, :2] | |
| box_scale = ( | |
| box_scale.clamp(min=1)[:, None, :].expand(-1, 2, | |
| 2).reshape(-1, 4)) | |
| aug_scale = box_scale * frac # [n,4] | |
| offset = ( | |
| torch.randn(times, box.shape[0], 4, device=box.device) * | |
| aug_scale[None, ...]) | |
| new_box = box.clone()[None, ...].expand(times, box.shape[0], | |
| -1) + offset | |
| return new_box | |
| return [ | |
| _aug_single(data_samples.gt_instances.bboxes) | |
| for data_samples in batch_data_samples | |
| ] | |