|
|
|
import unittest |
|
|
|
from mmdet.datasets import VOCDataset |
|
|
|
|
|
class TestVOCDataset(unittest.TestCase):
    """Loading checks for ``VOCDataset`` against the bundled VOC fixtures.

    Each test builds the dataset for one VOC year, fully initializes it, and
    checks the number of samples, the instances of the single sample, and the
    category ids reported for it.
    """

    def _make_dataset(self, ann_file, sub_data_root, filter_cfg):
        """Construct a ``VOCDataset`` over the test fixture tree and init it.

        Args:
            ann_file: Value forwarded as the dataset's ``ann_file``.
            sub_data_root: Value placed under ``data_prefix['sub_data_root']``.
            filter_cfg: Filtering options forwarded unchanged as
                ``filter_cfg``.

        Returns:
            The dataset after ``full_init()`` has been called on it.
        """
        voc = VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file=ann_file,
            data_prefix=dict(sub_data_root=sub_data_root),
            filter_cfg=filter_cfg,
            pipeline=[])
        voc.full_init()
        return voc

    def test_voc2007_init(self):
        """VOC2007 fixture: one sample with two instances (cats 11 and 14)."""
        # Unlike the 2012 case, this also passes bbox_min_size.
        voc = self._make_dataset(
            'VOC2007/ImageSets/Main/trainval.txt',
            'VOC2007/',
            dict(filter_empty_gt=True, min_size=32, bbox_min_size=32))
        self.assertEqual(len(voc), 1)

        # Re-read the raw records directly, independent of full_init().
        records = voc.load_data_list()
        self.assertEqual(len(records), 1)
        self.assertEqual(len(records[0]['instances']), 2)
        self.assertEqual(voc.get_cat_ids(0), [11, 14])

    def test_voc2012_init(self):
        """VOC2012 fixture: one sample with a single instance (cat 18)."""
        voc = self._make_dataset(
            'VOC2012/ImageSets/Main/trainval.txt',
            'VOC2012/',
            dict(filter_empty_gt=True, min_size=32))
        self.assertEqual(len(voc), 1)

        # Re-read the raw records directly, independent of full_init().
        records = voc.load_data_list()
        self.assertEqual(len(records), 1)
        self.assertEqual(len(records[0]['instances']), 1)
        self.assertEqual(voc.get_cat_ids(0), [18])
|
|