Spaces:
Build error
Build error
| import copy | |
| import warnings | |
| from mmcv.cnn import VGG | |
| from mmcv.runner.hooks import HOOKS, Hook | |
| from mmdet.datasets.builder import PIPELINES | |
| from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile | |
| from mmdet.models.dense_heads import GARPNHead, RPNHead | |
| from mmdet.models.roi_heads.mask_heads import FusedSemanticHead | |
def replace_ImageToTensor(pipelines):
    """Return a copy of a pipeline config with ``ImageToTensor`` swapped out.

    Every ``ImageToTensor`` step, at the top level or nested inside a
    ``MultiScaleFlipAug`` step's ``transforms`` list, is replaced by a plain
    ``{'type': 'DefaultFormatBundle'}`` entry. This is normally useful in
    batch inference. Each replacement emits a ``UserWarning``. The input list
    is deep-copied first, so the caller's configs are never mutated.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: The new pipeline list with all ImageToTensor replaced by
            DefaultFormatBundle.

    Raises:
        AssertionError: If a ``MultiScaleFlipAug`` step has no
            ``transforms`` key.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    converted = copy.deepcopy(pipelines)
    for idx, step in enumerate(converted):
        step_type = step['type']
        if step_type == 'MultiScaleFlipAug':
            # Test-time augmentation wraps its own transform list; recurse.
            assert 'transforms' in step
            step['transforms'] = replace_ImageToTensor(step['transforms'])
            continue
        if step_type != 'ImageToTensor':
            continue
        warnings.warn(
            '"ImageToTensor" pipeline is replaced by '
            '"DefaultFormatBundle" for batch inference. It is '
            'recommended to manually replace it in the test '
            'data pipeline in your config file.', UserWarning)
        # Assigning by index is safe here: the list length never changes.
        converted[idx] = {'type': 'DefaultFormatBundle'}
    return converted
def get_loading_pipeline(pipeline):
    """Only keep the image- and annotation-loading steps of a pipeline.

    A step is kept when its ``type`` resolves, through the ``PIPELINES``
    registry, to exactly ``LoadImageFromFile`` or ``LoadAnnotations``.
    Subclasses of those loaders do not match because membership is tested
    against the exact classes. Kept steps appear in their original order.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The loading-related configs only. These are the same
            dict objects as in ``pipeline``, not copies.

    Raises:
        AssertionError: If the number of kept steps is not exactly 2.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines == get_loading_pipeline(pipelines)
    """
    # TODO: use a more elegant way to distinguish loading modules.
    loading_classes = (LoadImageFromFile, LoadAnnotations)
    # The ``is not None`` guard skips types the registry does not resolve.
    loading_pipeline_cfg = [
        cfg for cfg in pipeline
        if (obj_cls := PIPELINES.get(cfg['type'])) is not None
        and obj_cls in loading_classes
    ]
    assert len(loading_pipeline_cfg) == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline.'
    return loading_pipeline_cfg
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Check that model heads' ``num_classes`` agree with the dataset.

    The hook is registered in ``HOOKS`` so configs can refer to it by its
    class name. Before every train and val epoch it walks ``runner.model``.
    Each submodule exposing a ``num_classes`` attribute must have a value
    equal to ``len(dataset.CLASSES)``; a few excluded module types are
    skipped.
    """

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        If ``dataset.CLASSES`` is ``None``, only a warning is logged and
        nothing is checked.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.

        Raises:
            AssertionError: If ``dataset.CLASSES`` is a plain string, or if
                a checked module's ``num_classes`` differs from
                ``len(dataset.CLASSES)``.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset.__class__.__name__} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string (e.g. ``CLASSES = ('person')`` missing the trailing
        # comma) would make ``len`` count characters instead of classes.
        assert not isinstance(dataset.CLASSES, str), \
            (f'`CLASSES` in {dataset.__class__.__name__} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'CLASSES = ({dataset.CLASSES},)')

        num_dataset_classes = len(dataset.CLASSES)
        # Excluded from the check. Presumably their ``num_classes`` does not
        # describe the dataset's detection categories (e.g. class-agnostic
        # proposal heads). Confirm before extending this list.
        excluded_types = (RPNHead, VGG, FusedSemanticHead, GARPNHead)
        for _, module in model.named_modules():
            if not hasattr(module, 'num_classes') or isinstance(
                    module, excluded_types):
                continue
            assert module.num_classes == num_dataset_classes, \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `CLASSES` '
                 f'({num_dataset_classes}) in '
                 f'{dataset.__class__.__name__}')

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)