File size: 2,985 Bytes
6c9ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os.path as osp
import tempfile
from copy import deepcopy
import pytest
from mmengine.config import Config
from mmdet.utils import replace_cfg_vals
def test_replace_cfg_vals():
    """Check ``${...}`` placeholder substitution done by ``replace_cfg_vals``.

    The config built here exercises four kinds of placeholder:

    * several scalar placeholders embedded in one string (``work_dir``),
    * a placeholder that stands for a whole sub-config (``${model}``),
    * dotted-path placeholders reaching into nested values
      (``${model.train_cfg.rpn_proposal.nms.iou_threshold}``),
    * a string placeholder embedded in surrounding text (``test_str``).

    It then checks that embedding a dict, list or tuple placeholder inside
    a longer string makes ``replace_cfg_vals`` raise ``AssertionError``.
    """
    # A temporary directory used as a context manager removes the config
    # file on exit, so the test leaves nothing behind in the temp folder.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg_name = 'replace_cfg_vals_cfg'
        cfg_path = osp.join(tmp_dir, f'{cfg_name}.py')
        # NOTE(review): presumably ``Config`` reads the file passed as
        # ``filename``, so it has to exist on disk -- confirm against
        # mmengine. The file content itself is irrelevant to the test.
        with open(cfg_path, 'w') as f:
            f.write('configs')

        ori_cfg_dict = dict()
        ori_cfg_dict['cfg_name'] = cfg_name
        ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}'
        ori_cfg_dict['percent'] = 5
        ori_cfg_dict['fold'] = 1
        ori_cfg_dict['model_wrapper'] = dict(
            type='SoftTeacher', detector='${model}')
        ori_cfg_dict['model'] = dict(
            type='FasterRCNN',
            backbone=dict(type='ResNet'),
            neck=dict(type='FPN'),
            rpn_head=dict(type='RPNHead'),
            roi_head=dict(type='StandardRoIHead'),
            train_cfg=dict(
                rpn=dict(
                    assigner=dict(type='MaxIoUAssigner'),
                    sampler=dict(type='RandomSampler'),
                ),
                rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)),
                rcnn=dict(
                    assigner=dict(type='MaxIoUAssigner'),
                    sampler=dict(type='RandomSampler'),
                ),
            ),
            test_cfg=dict(
                rpn=dict(nms=dict(type='nms', iou_threshold=0.7)),
                rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)),
            ),
        )
        ori_cfg_dict['iou_threshold'] = dict(
            rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}',
            test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}',
            test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}',
        )
        ori_cfg_dict['str'] = 'Hello, world!'
        ori_cfg_dict['dict'] = {'Hello': 'world!'}
        ori_cfg_dict['list'] = [
            'Hello, world!',
        ]
        ori_cfg_dict['tuple'] = ('Hello, world!', )
        ori_cfg_dict['test_str'] = 'xxx${str}xxx'

        ori_cfg = Config(ori_cfg_dict, filename=cfg_path)
        # Work on a copy so ``ori_cfg`` stays available as the reference
        # for the comparisons below.
        updated_cfg = replace_cfg_vals(deepcopy(ori_cfg))

        # Scalar placeholders are substituted inside the string.
        assert updated_cfg.work_dir == f'work_dirs/{cfg_name}/5/1'
        # After replacement, ``model`` is expected to carry the original
        # model config as its ``detector`` entry.
        assert updated_cfg.model.detector == ori_cfg.model
        # A dotted-path placeholder resolves to the nested value.
        assert updated_cfg.iou_threshold.rpn_proposal_nms \
            == ori_cfg.model.train_cfg.rpn_proposal.nms.iou_threshold
        assert updated_cfg.test_str == 'xxxHello, world!xxx'

        # Non-string values cannot be spliced into a longer string; each
        # case is checked on its own fresh copy of the original config.
        invalid_entries = (
            ('test_dict', 'xxx${dict}xxx'),
            ('test_list', 'xxx${list}xxx'),
            ('test_tuple', 'xxx${tuple}xxx'),
        )
        for key, value in invalid_entries:
            cfg = deepcopy(ori_cfg)
            cfg[key] = value
            # Only the call under test sits inside ``raises`` so that a
            # failure during setup cannot be mistaken for the expected one.
            with pytest.raises(AssertionError):
                replace_cfg_vals(cfg)
|