File size: 4,670 Bytes
6c9ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from argparse import ArgumentParser
import mmcv
from mmengine.config import Config
from mmengine.logging import MMLogger
from mmengine.utils import mkdir_or_exist
from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.utils import register_all_modules
def parse_args():
    """Parse the command-line options of the image inference benchmark.

    Returns:
        argparse.Namespace: The parsed options. ``config`` and
        ``checkpoint_root`` are positional; all other options are optional
        and fall back to the defaults declared below.
    """
    arg_parser = ArgumentParser()

    # Required positional arguments.
    arg_parser.add_argument('config', help='test config file path')
    arg_parser.add_argument(
        'checkpoint_root', help='Checkpoint file root path')

    # Input image and model selection.
    arg_parser.add_argument(
        '--img', default='demo/demo.jpg', help='Image file')
    arg_parser.add_argument('--aug', action='store_true', help='aug test')
    arg_parser.add_argument('--model-name', help='model name to inference')

    # Visualization / output options.
    arg_parser.add_argument(
        '--show', action='store_true', help='show results')
    arg_parser.add_argument(
        '--out-dir', default=None, help='Dir to output file')
    arg_parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='the interval of show (s), 0 is block')

    # Runtime and drawing options.
    arg_parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    arg_parser.add_argument(
        '--palette',
        default='coco',
        choices=['coco', 'voc', 'citys', 'random'],
        help='Color palette used for visualization')
    arg_parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')

    return arg_parser.parse_args()
def inference_model(config_name, checkpoint, visualizer, args, logger=None):
    """Run single-image inference for one model and optionally visualize it.

    Args:
        config_name (str): Path of the model config file.
        checkpoint (str): Path of the checkpoint file passed to
            ``init_detector``.
        visualizer: Visualizer object; its ``dataset_meta`` is overwritten
            with the model's and its ``add_datasample`` draws the result.
        args: Parsed command-line options. Reads ``aug``, ``palette``,
            ``device``, ``img``, ``show``, ``out_dir``, ``wait_time`` and
            ``score_thr``.
        logger: Accepted so callers can pass a logger; not used here.

    Returns:
        The value returned by ``inference_detector`` for ``args.img``.

    Raises:
        NotImplementedError: If ``args.aug`` is set.
    """
    cfg = Config.fromfile(config_name)
    if args.aug:
        raise NotImplementedError()

    model = init_detector(
        cfg, checkpoint, palette=args.palette, device=args.device)
    visualizer.dataset_meta = model.dataset_meta

    # test a single image
    result = inference_detector(model, args.img)

    # show the results
    if args.show or args.out_dir is not None:
        img = mmcv.imread(args.img)
        img = mmcv.imconvert(img, 'bgr', 'rgb')

        out_file = None
        if args.out_dir is not None:
            mkdir_or_exist(args.out_dir)
            # Swap only the file extension. A plain ``replace('py', 'jpg')``
            # would also rewrite any 'py' inside the name itself
            # (e.g. 'pyramid_x.py' -> 'jpgramid_x.jpg').
            stem = osp.splitext(osp.basename(config_name))[0]
            out_file = osp.join(args.out_dir, f'{stem}.jpg')

        visualizer.add_datasample(
            'result',
            img,
            data_sample=result,
            draw_gt=False,
            show=args.show,
            wait_time=args.wait_time,
            out_file=out_file,
            pred_score_thr=args.score_thr)

    return result
# Sample test whether the inference code is correct
def main(args):
    """Run image inference for one named model or for every listed model.

    ``args.config`` is loaded as a config whose top-level keys are model
    names; each value is a model-info mapping (or a list of them) with
    ``'config'`` and ``'checkpoint'`` entries. Checkpoint paths are joined
    onto ``args.checkpoint_root``.

    * With ``args.model_name`` set, only the first model-info of that key is
      run and any exception propagates to the caller.
    * Otherwise every model-info of every key is run; failures are logged at
      ERROR level to ``benchmark_test_image.log`` and the loop continues.

    Args:
        args: Parsed command-line options (see ``parse_args``).

    Raises:
        RuntimeError: If ``args.model_name`` is given but is not a key of the
            loaded config.
    """
    # register all modules in mmdet into the registries
    register_all_modules()

    config = Config.fromfile(args.config)

    # init visualizer
    visualizer_cfg = dict(type='DetLocalVisualizer', name='visualizer')
    visualizer = VISUALIZERS.build(visualizer_cfg)

    # test single model
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')

        model_infos = config[args.model_name]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        # Only the first entry of the selected model is exercised.
        model_info = model_infos[0]
        config_name = model_info['config'].strip()
        print(f'processing: {config_name}', flush=True)
        checkpoint = osp.join(args.checkpoint_root,
                              model_info['checkpoint'].strip())
        # build the model from a config file and a checkpoint file
        inference_model(config_name, checkpoint, visualizer, args)
        return

    # test all model
    logger = MMLogger.get_instance(
        name='MMLogger',
        log_file='benchmark_test_image.log',
        log_level=logging.ERROR)

    for model_key in config:
        model_infos = config[model_key]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_name = model_info['config'].strip()
            checkpoint = osp.join(args.checkpoint_root,
                                  model_info['checkpoint'].strip())
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_name, checkpoint, visualizer, args,
                                logger)
            except Exception as e:
                # Best-effort sweep: record the failure and move on to the
                # next model instead of aborting the whole benchmark.
                logger.error(f'{config_name} : {repr(e)}')
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the benchmark.
    main(parse_args())
|