Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from mmdet.models.utils import (filter_gt_instances, rename_loss_dict, | |
| reweight_loss_dict) | |
| from mmdet.registry import MODELS | |
| from mmdet.structures import SampleList | |
| from mmdet.structures.bbox import bbox_project | |
| from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig | |
| from .base import BaseDetector | |
class SemiBaseDetector(BaseDetector):
    """Base class for semi-supervised detectors.

    Semi-supervised detectors typically consisting of a teacher model
    updated by exponential moving average and a student model updated
    by gradient descent.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # Teacher and student are two independently built instances of the
        # same detector config; only the student receives gradients.
        self.student = MODELS.build(detector)
        self.teacher = MODELS.build(detector)
        self.semi_train_cfg = semi_train_cfg
        self.semi_test_cfg = semi_test_cfg
        # FIX: ``semi_train_cfg`` defaults to None, so guard the ``.get``
        # call — the previous code raised AttributeError when constructed
        # without a semi-supervised training config.
        if self.semi_train_cfg is not None and \
                self.semi_train_cfg.get('freeze_teacher', True) is True:
            self.freeze(self.teacher)

    @staticmethod
    def freeze(model: nn.Module) -> None:
        """Freeze the model.

        Puts ``model`` into eval mode and disables gradient computation
        for all of its parameters.

        FIX: declared ``@staticmethod`` — it is invoked as
        ``self.freeze(self.teacher)``; as a plain instance method that
        call binds ``self`` to ``model`` and raises ``TypeError``.
        """
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def loss(self, multi_batch_inputs: Dict[str, Tensor],
             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
        """Calculate losses from multi-branch inputs and data samples.

        Args:
            multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch
                input images, each value with shape (N, C, H, W).
                Each value should usually be mean centered and std scaled.
            multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]):
                The dict of multi-branch data samples.

        Returns:
            dict: A dictionary of loss components
        """
        losses = dict()
        # Supervised branch: standard loss on labeled data.
        losses.update(**self.loss_by_gt_instances(
            multi_batch_inputs['sup'], multi_batch_data_samples['sup']))

        # Unsupervised branch: teacher predicts pseudo labels on the
        # weakly-augmented view, which are projected onto the student's
        # (strongly-augmented) view before computing the student loss.
        origin_pseudo_data_samples, batch_info = self.get_pseudo_instances(
            multi_batch_inputs['unsup_teacher'],
            multi_batch_data_samples['unsup_teacher'])
        multi_batch_data_samples[
            'unsup_student'] = self.project_pseudo_instances(
                origin_pseudo_data_samples,
                multi_batch_data_samples['unsup_student'])
        losses.update(**self.loss_by_pseudo_instances(
            multi_batch_inputs['unsup_student'],
            multi_batch_data_samples['unsup_student'], batch_info))
        return losses

    def loss_by_gt_instances(self, batch_inputs: Tensor,
                             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and ground-truth data
        samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components, each key prefixed with
            ``sup_`` and scaled by ``semi_train_cfg.sup_weight``.
        """
        losses = self.student.loss(batch_inputs, batch_data_samples)
        sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
        return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight))

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process. Defaults to None.

        Returns:
            dict: A dictionary of loss components, each key prefixed with
            ``unsup_`` and scaled by ``semi_train_cfg.unsup_weight``.
        """
        # Drop low-confidence pseudo boxes before computing the loss.
        batch_data_samples = filter_gt_instances(
            batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
        losses = self.student.loss(batch_inputs, batch_data_samples)
        pseudo_instances_num = sum(
            len(data_samples.gt_instances)
            for data_samples in batch_data_samples)
        # Zero the unsupervised loss weight when no pseudo instances
        # survive filtering, so the branch contributes nothing.
        unsup_weight = self.semi_train_cfg.get(
            'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model.

        The teacher's predictions are stored as ``gt_instances`` on the
        data samples, with boxes mapped back to the original image frame
        by the inverse of each sample's homography matrix.
        """
        self.teacher.eval()
        results_list = self.teacher.predict(
            batch_inputs, batch_data_samples, rescale=False)
        batch_info = {}
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results.pred_instances
            # Undo the teacher-branch augmentation: project boxes back to
            # the original image coordinates.
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)
        return batch_data_samples, batch_info

    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
                                 batch_data_samples: SampleList) -> SampleList:
        """Project pseudo instances.

        Copies the (original-frame) pseudo instances onto the student
        branch's samples and projects their boxes into the student view
        via each sample's homography matrix, then filters out boxes
        smaller than ``semi_train_cfg.min_pseudo_bbox_wh``.
        """
        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
                                                  batch_data_samples):
            # Deep-copy so later in-place box edits on the student branch
            # cannot corrupt the teacher's pseudo instances.
            data_samples.gt_instances = copy.deepcopy(
                pseudo_instances.gt_instances)
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.tensor(data_samples.homography_matrix).to(
                    self.data_preprocessor.device), data_samples.img_shape)
        wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        The sub-model used is selected by ``semi_test_cfg.predict_on``
        (``'teacher'``, the default, or anything else for the student).

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The returns value is DetDataSample,
            which usually contain 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        if self.semi_test_cfg.get('predict_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='predict')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='predict')

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> SampleList:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        The sub-model used is selected by ``semi_test_cfg.forward_on``
        (``'teacher'``, the default, or anything else for the student).

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            tuple: A tuple of features from ``rpn_head`` and ``roi_head``
            forward.
        """
        if self.semi_test_cfg.get('forward_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='tensor')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='tensor')

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        The sub-model used is selected by ``semi_test_cfg.extract_feat_on``
        (``'teacher'``, the default, or anything else for the student).

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
        if self.semi_test_cfg.get('extract_feat_on', 'teacher') == 'teacher':
            return self.teacher.extract_feat(batch_inputs)
        else:
            return self.student.extract_feat(batch_inputs)

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Add teacher and student prefixes to model parameter names.

        Allows loading a plain (single-model) detector checkpoint: each
        key is duplicated under ``teacher.`` and ``student.`` prefixes so
        both sub-models are initialized from the same weights. Checkpoints
        that already contain teacher/student keys are loaded unchanged.
        """
        if not any('student' in key or 'teacher' in key
                   for key in state_dict.keys()):
            keys = list(state_dict.keys())
            state_dict.update({'teacher.' + k: state_dict[k] for k in keys})
            state_dict.update({'student.' + k: state_dict[k] for k in keys})
            for k in keys:
                state_dict.pop(k)
        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )