# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
from typing import (Dict, Generator, List, Optional, Sequence, Tuple, Union)

import mmcv
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.infer.infer import ModelType
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
from rich.progress import track

from mmpose.evaluation.functional import nms
from mmpose.registry import DATASETS, INFERENCERS
from mmpose.structures import merge_data_samples

from .base_mmpose_inferencer import BaseMMPoseInferencer
from .utils import default_det_models
| try: | |
| from mmdet.apis.det_inferencer import DetInferencer | |
| has_mmdet = True | |
| except (ImportError, ModuleNotFoundError): | |
| has_mmdet = False | |
| InstanceList = List[InstanceData] | |
| InputType = Union[str, np.ndarray] | |
| InputsType = Union[InputType, Sequence[InputType]] | |
| PredType = Union[InstanceData, InstanceList] | |
| ImgType = Union[np.ndarray, Sequence[np.ndarray]] | |
| ConfigType = Union[Config, ConfigDict] | |
| ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] | |
class Pose2DInferencer(BaseMMPoseInferencer):
    """The inferencer for 2D pose estimation.

    Args:
        model (str, optional): Pretrained 2D pose estimation algorithm.
            It's the path to the config file or the model name defined in
            metafile. For example, it could be:

            - model alias, e.g. ``'body'``,
            - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
            - config path

            Defaults to ``None``.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and "model" is a model name of metafile, the weights
            will be loaded from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmpose".
        det_model (str, optional): Config path or alias of detection model.
            Defaults to None.
        det_weights (str, optional): Path to the checkpoints of detection
            model. Defaults to None.
        det_cat_ids (int or list[int], optional): Category id for
            detection model. Defaults to None.
        output_heatmaps (bool, optional): Flag to visualize predicted
            heatmaps. If set to None, the default setting from the model
            config will be used. Default is None.
    """

    # Keyword arguments accepted at each stage of the pipeline; the base
    # class dispatches ``**kwargs`` from ``__call__`` by these sets.
    preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis',
        'show',
        'wait_time',
        'draw_bbox',
        'radius',
        'thickness',
        'kpt_thr',
        'vis_out_dir',
    }
    postprocess_kwargs: set = {'pred_out_dir'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmpose',
                 det_model: Optional[Union[ModelType, str]] = None,
                 det_weights: Optional[str] = None,
                 det_cat_ids: Optional[Union[int, Tuple]] = None,
                 output_heatmaps: Optional[bool] = None) -> None:

        init_default_scope(scope)
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)
        # SyncBN cannot run on a single device at inference time.
        self.model = revert_sync_batchnorm(self.model)

        if output_heatmaps is not None:
            self.model.test_cfg['output_heatmaps'] = output_heatmaps

        # assign dataset metainfo to self.visualizer
        self.visualizer.set_dataset_meta(self.model.dataset_meta)

        # initialize detector for top-down models
        if self.cfg.data_mode == 'topdown':
            # Derive the object category (e.g. 'body', 'face') from the
            # module path of the registered dataset class.
            object_type = DATASETS.get(self.cfg.dataset_type).__module__.split(
                'datasets.')[-1].split('.')[0].lower()

            if det_model in ('whole_image', 'whole-image') or \
                    (det_model is None and
                     object_type not in default_det_models):
                # No detector: the whole image will be used as one bbox.
                self.detector = None
            else:
                det_scope = 'mmdet'
                if det_model is None:
                    # Fall back to the default detector for this object type.
                    det_info = default_det_models[object_type]
                    det_model, det_weights, det_cat_ids = det_info[
                        'model'], det_info['weights'], det_info['cat_ids']
                elif os.path.exists(det_model):
                    # ``det_model`` is a config file; read its scope from it.
                    det_cfg = Config.fromfile(det_model)
                    det_scope = det_cfg.default_scope

                if has_mmdet:
                    self.detector = DetInferencer(
                        det_model, det_weights, device=device, scope=det_scope)
                else:
                    raise RuntimeError(
                        'MMDetection (v3.0.0 or above) is required to build '
                        'inferencers for top-down pose estimation models.')

            # Normalize category ids to a tuple for uniform iteration.
            if isinstance(det_cat_ids, (tuple, list)):
                self.det_cat_ids = det_cat_ids
            else:
                self.det_cat_ids = (det_cat_ids, )

        self._video_input = False

    def preprocess_single(self,
                          input: InputType,
                          index: int,
                          bbox_thr: float = 0.3,
                          nms_thr: float = 0.3,
                          bboxes: Optional[Union[List[List],
                                                 List[np.ndarray],
                                                 np.ndarray]] = None):
        """Process a single input into a model-feedable format.

        Args:
            input (InputType): Input given by user.
            index (int): index of the input
            bbox_thr (float): threshold for bounding box detection.
                Defaults to 0.3.
            nms_thr (float): IoU threshold for bounding box NMS.
                Defaults to 0.3.
            bboxes (list or ndarray, optional): User-provided bounding
                boxes, used by top-down models when no detector is built.
                Defaults to None, which is treated as an empty list.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """
        # ``None`` sentinel instead of a mutable ``[]`` default (the shared
        # mutable-default pitfall); the observable behavior is identical.
        if bboxes is None:
            bboxes = []

        if isinstance(input, str):
            data_info = dict(img_path=input)
        else:
            # Synthesize a zero-padded pseudo file name for array inputs so
            # downstream result dumping has a usable ``img_path``.
            data_info = dict(img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
        data_info.update(self.model.dataset_meta)

        if self.cfg.data_mode == 'topdown':
            if self.detector is not None:
                det_results = self.detector(
                    input, return_datasample=True)['predictions']
                pred_instance = det_results[0].pred_instances.cpu().numpy()
                # Stack boxes with their scores: (n, 5) = x1, y1, x2, y2, s.
                bboxes = np.concatenate(
                    (pred_instance.bboxes, pred_instance.scores[:, None]),
                    axis=1)

                # Keep detections whose label is one of the requested
                # categories and whose score passes ``bbox_thr``.
                label_mask = np.zeros(len(bboxes), dtype=np.uint8)
                for cat_id in self.det_cat_ids:
                    label_mask = np.logical_or(label_mask,
                                               pred_instance.labels == cat_id)

                bboxes = bboxes[np.logical_and(
                    label_mask, pred_instance.scores > bbox_thr)]
                bboxes = bboxes[nms(bboxes, nms_thr)]

            data_infos = []
            if len(bboxes) > 0:
                # One pipeline sample per bounding box.
                for bbox in bboxes:
                    inst = data_info.copy()
                    inst['bbox'] = bbox[None, :4]
                    inst['bbox_score'] = bbox[4:5]
                    data_infos.append(self.pipeline(inst))
            else:
                inst = data_info.copy()

                # get bbox from the image size
                if isinstance(input, str):
                    input = mmcv.imread(input)
                h, w = input.shape[:2]

                inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
                inst['bbox_score'] = np.ones(1, dtype=np.float32)
                data_infos.append(self.pipeline(inst))

        else:  # bottom-up
            data_infos = [self.pipeline(data_info)]

        return data_infos

    def forward(self, inputs: Union[dict, tuple], bbox_thr: float = -1):
        """Forward the model and post-filter predicted instances.

        Args:
            inputs (dict | tuple): Batched data produced by ``preprocess``.
            bbox_thr (float): When positive, instances whose
                ``bbox_scores`` do not exceed it are dropped.
                Defaults to -1 (no filtering).

        Returns:
            list: Data samples; for top-down models the per-bbox samples
            are merged into one sample per image.
        """
        data_samples = super().forward(inputs)
        if self.cfg.data_mode == 'topdown':
            data_samples = [merge_data_samples(data_samples)]
        if bbox_thr > 0:
            for ds in data_samples:
                if 'bbox_scores' in ds.pred_instances:
                    ds.pred_instances = ds.pred_instances[
                        ds.pred_instances.bbox_scores > bbox_thr]
        return data_samples

    def __call__(
        self,
        inputs: InputsType,
        return_datasample: bool = False,
        batch_size: int = 1,
        out_dir: Optional[str] = None,
        **kwargs,
    ) -> Generator:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            return_datasample (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            out_dir (str, optional): directory to save visualization
                results and predictions. Will be overoden if vis_out_dir or
                pred_out_dir are given. Defaults to None
            **kwargs: Key words arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Yields:
            dict: Inference and visualization results, one per batch.
            (This method is a generator, not a plain function.)
        """
        if out_dir is not None:
            if 'vis_out_dir' not in kwargs:
                kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
            if 'pred_out_dir' not in kwargs:
                kwargs['pred_out_dir'] = f'{out_dir}/predictions'

        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        # preprocessing
        if isinstance(inputs, str) and inputs.startswith('webcam'):
            inputs = self._get_webcam_inputs(inputs)
            batch_size = 1
            if not visualize_kwargs.get('show', False):
                warnings.warn('The display mode is closed when using webcam '
                              'input. It will be turned on automatically.')
            visualize_kwargs['show'] = True
        else:
            inputs = self._inputs_to_list(inputs)

        # ``bbox_thr`` is consumed in :meth:`forward`, so forward it there.
        forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
        inputs = self.preprocess(
            inputs, batch_size=batch_size, **preprocess_kwargs)

        if not hasattr(self, 'detector'):
            # Bottom-up models (no ``detector`` attribute) get a progress
            # bar over the input stream.
            inputs = track(inputs, description='Inference')

        for proc_inputs, ori_inputs in inputs:
            preds = self.forward(proc_inputs, **forward_kwargs)

            visualization = self.visualize(ori_inputs, preds,
                                           **visualize_kwargs)
            results = self.postprocess(preds, visualization, return_datasample,
                                       **postprocess_kwargs)
            yield results

        if self._video_input:
            self._finalize_video_processing(
                postprocess_kwargs.get('pred_out_dir', ''))