Spaces:
Runtime error
Runtime error
| from collections import OrderedDict | |
| from mmcv.utils import print_log | |
| from mmdet.core import eval_map, eval_recalls | |
| from .builder import DATASETS | |
| from .xml_style import XMLDataset | |
# Register this class in the DATASETS registry imported from .builder.
@DATASETS.register_module()
class VOCDataset(XMLDataset):
    """PASCAL VOC dataset with the 20 VOC class names in ``CLASSES``.

    The VOC release year (2007 or 2012) is inferred from ``img_prefix`` and
    selects the AP protocol passed to ``eval_map`` in :meth:`evaluate`.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Build the dataset and infer the VOC year from ``img_prefix``.

        Args:
            **kwargs: Forwarded unchanged to :class:`XMLDataset`.

        Raises:
            ValueError: If ``img_prefix`` contains neither 'VOC2007' nor
                'VOC2012' (or is empty/None).
        """
        super().__init__(**kwargs)
        # Treat a missing prefix as empty so the intended ValueError below is
        # raised instead of a TypeError from ``'VOC2007' in None``.
        img_prefix = self.img_prefix or ''
        if 'VOC2007' in img_prefix:
            self.year = 2007
        elif 'VOC2012' in img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'. A list must contain exactly one entry.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (int | Sequence[int]): Proposal number(s) used for
                evaluating recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | Sequence[float]): A single IoU threshold or a
                sequence of them. Default: 0.5.
            scale_ranges (list[tuple], optional): Accepted for interface
                compatibility but currently NOT forwarded to ``eval_map``;
                all bounding boxes are evaluated. Default: None.

        Returns:
            dict[str, float]: AP/recall metrics. For 'mAP': one ``APxx``
            entry (rounded to 3 decimals) per threshold plus their unrounded
            mean under ``'mAP'``. For 'recall': ``recall@{num}@{iou}`` for
            every (proposal number, threshold) pair, plus ``AR@{num}``
            (recall averaged over thresholds) when more than one threshold
            is given.

        Raises:
            KeyError: If ``metric`` is not supported.
            ValueError: If ``iou_thr`` is an empty sequence.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')

        # Normalize scalar-or-sequence arguments once so both branches below
        # can iterate them (the recall branch previously enumerated the raw
        # float ``iou_thr`` and crashed on the default value).
        if isinstance(iou_thr, (int, float)):
            iou_thrs = [iou_thr]
        else:
            iou_thrs = list(iou_thr)
        if not iou_thrs:
            raise ValueError('iou_thr must contain at least one threshold')
        if isinstance(proposal_nums, int):
            proposal_nums = (proposal_nums, )

        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        if metric == 'mAP':
            # VOC2007 is evaluated with the 'voc07' protocol name; any other
            # year passes the class-name tuple instead.
            # NOTE(review): in mmdet, eval_map uses 11-point AP only for
            # 'voc07' — confirm against the installed version.
            ds_name = 'voc07' if self.year == 2007 else self.CLASSES
            mean_aps = []
            for thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}',
                          logger=logger)
                # scale_ranges is deliberately not forwarded: the summary
                # below rounds and averages ``mean_ap`` as one number per
                # threshold. NOTE(review): forwarding it likely makes
                # eval_map return per-scale values; handle that first.
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=None,
                    iou_thr=thr,
                    dataset=ds_name,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            # ``recalls`` is indexed as [proposal_num_idx, iou_thr_idx].
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, thr in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{thr}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                # Average recall over IoU thresholds for each proposal number.
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results