Saurabh1105's picture
MMdet Model for Image Segmentation
6c9ac8f
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import numpy as np
import torch
from mmdet.datasets import OpenImagesDataset
from mmdet.evaluation import OpenImagesMetric
from mmdet.utils import register_all_modules
class TestOpenImagesMetric(unittest.TestCase):
def _create_dummy_results(self):
bboxes = np.array([[23.2172, 31.7541, 987.3413, 357.8443],
[100, 120, 130, 150], [150, 160, 190, 200],
[250, 260, 350, 360]])
scores = np.array([1.0, 0.98, 0.96, 0.95])
labels = np.array([0, 0, 0, 0])
return dict(
bboxes=torch.from_numpy(bboxes),
scores=torch.from_numpy(scores),
labels=torch.from_numpy(labels))
def test_init(self):
# test invalid iou_thrs
with self.assertRaises(AssertionError):
OpenImagesMetric(iou_thrs={'a', 0.5}, ioa_thrs={'b', 0.5})
# test ioa and iou_thrs length not equal
with self.assertRaises(AssertionError):
OpenImagesMetric(iou_thrs=[0.5, 0.75], ioa_thrs=[0.5])
metric = OpenImagesMetric(iou_thrs=0.6)
self.assertEqual(metric.iou_thrs, [0.6])
def test_eval(self):
register_all_modules()
dataset = OpenImagesDataset(
data_root='tests/data/OpenImages/',
ann_file='annotations/oidv6-train-annotations-bbox.csv',
data_prefix=dict(img='OpenImages/train/'),
label_file='annotations/class-descriptions-boxable.csv',
hierarchy_file='annotations/bbox_labels_600_hierarchy.json',
meta_file='annotations/image-metas.pkl',
pipeline=[
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'instances'))
])
dataset.full_init()
data_sample = dataset[0]['data_samples'].to_dict()
data_sample['pred_instances'] = self._create_dummy_results()
metric = OpenImagesMetric()
metric.dataset_meta = dataset.metainfo
metric.process({}, [data_sample])
results = metric.evaluate(size=len(dataset))
targets = {'openimages/AP50': 1.0, 'openimages/mAP': 1.0}
self.assertDictEqual(results, targets)
# test multi-threshold
metric = OpenImagesMetric(iou_thrs=[0.1, 0.5], ioa_thrs=[0.1, 0.5])
metric.dataset_meta = dataset.metainfo
metric.process({}, [data_sample])
results = metric.evaluate(size=len(dataset))
targets = {
'openimages/AP10': 1.0,
'openimages/AP50': 1.0,
'openimages/mAP': 1.0
}
self.assertDictEqual(results, targets)