File size: 4,670 Bytes
6c9ac8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from argparse import ArgumentParser

import mmcv
from mmengine.config import Config
from mmengine.logging import MMLogger
from mmengine.utils import mkdir_or_exist

from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.utils import register_all_modules


def parse_args():
    """Parse the command-line options of the image inference benchmark.

    Two positionals are required: the benchmark config file and the root
    directory of the checkpoints. All remaining options are optional and
    control the test image, visualization and inference device.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    # (names, keyword options) pairs; they are registered in the listed
    # order so that the generated usage/help text keeps its layout.
    arg_specs = (
        (('config', ), dict(help='test config file path')),
        (('checkpoint_root', ), dict(help='Checkpoint file root path')),
        (('--img', ), dict(default='demo/demo.jpg', help='Image file')),
        (('--aug', ), dict(action='store_true', help='aug test')),
        (('--model-name', ), dict(help='model name to inference')),
        (('--show', ), dict(action='store_true', help='show results')),
        (('--out-dir', ), dict(default=None, help='Dir to output file')),
        (('--wait-time', ),
         dict(
             type=float,
             default=1,
             help='the interval of show (s), 0 is block')),
        (('--device', ),
         dict(default='cuda:0', help='Device used for inference')),
        (('--palette', ),
         dict(
             default='coco',
             choices=['coco', 'voc', 'citys', 'random'],
             help='Color palette used for visualization')),
        (('--score-thr', ),
         dict(type=float, default=0.3, help='bbox score threshold')),
    )

    cli = ArgumentParser()
    for names, options in arg_specs:
        cli.add_argument(*names, **options)
    return cli.parse_args()


def inference_model(config_name, checkpoint, visualizer, args, logger=None):
    """Run single-image inference with one model and optionally draw it.

    Args:
        config_name (str): Path of the model config file.
        checkpoint (str): Checkpoint path handed to ``init_detector``.
        visualizer: Visualizer object; its ``dataset_meta`` is set from the
            built model and its ``add_datasample`` draws the prediction.
        args (argparse.Namespace): Parsed options. The fields read here
            are ``aug``, ``palette``, ``device``, ``img``, ``show``,
            ``out_dir``, ``wait_time`` and ``score_thr``.
        logger: Accepted so callers can pass a logger; not used here.

    Returns:
        The result of ``inference_detector`` for ``args.img``.

    Raises:
        NotImplementedError: If ``args.aug`` is set.
    """
    cfg = Config.fromfile(config_name)
    if args.aug:
        raise NotImplementedError()

    model = init_detector(
        cfg, checkpoint, palette=args.palette, device=args.device)
    visualizer.dataset_meta = model.dataset_meta

    # test a single image
    result = inference_detector(model, args.img)

    # show the results
    if args.show or args.out_dir is not None:
        img = mmcv.imread(args.img)
        img = mmcv.imconvert(img, 'bgr', 'rgb')
        out_file = None
        if args.out_dir is not None:
            out_dir = args.out_dir
            mkdir_or_exist(out_dir)

            # Swap only the extension of the config's base name. A plain
            # ``replace('py', 'jpg')`` would also rewrite every other "py"
            # in the name (e.g. "pyramid_net.py" -> "jpgramid_net.jpg").
            stem = osp.splitext(osp.basename(config_name))[0]
            out_file = osp.join(out_dir, stem + '.jpg')

        visualizer.add_datasample(
            'result',
            img,
            data_sample=result,
            draw_gt=False,
            show=args.show,
            wait_time=args.wait_time,
            out_file=out_file,
            pred_score_thr=args.score_thr)

    return result


def _as_model_list(model_infos):
    """Return ``model_infos`` unchanged if it is a list, else wrap it."""
    if isinstance(model_infos, list):
        return model_infos
    return [model_infos]


# Sample test whether the inference code is correct
def main(args):
    """Run image inference for one model or for every model in the config.

    With ``args.model_name`` set, only the first entry stored under that
    key of the benchmark config is run and any error propagates to the
    caller. Otherwise every entry of every key is run; a failure of one
    model is written to ``benchmark_test_image.log`` and the loop goes on
    with the next model.

    Args:
        args (argparse.Namespace): Parsed command-line arguments, as
            returned by ``parse_args``.

    Raises:
        RuntimeError: If ``args.model_name`` is given but is not a key of
            the benchmark config.
    """
    # register all modules in mmdet into the registries
    register_all_modules()

    config = Config.fromfile(args.config)

    # init visualizer
    visualizer_cfg = dict(type='DetLocalVisualizer', name='visualizer')
    visualizer = VISUALIZERS.build(visualizer_cfg)

    # test single model
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')

        # Only the first entry registered under the name is exercised.
        model_info = _as_model_list(config[args.model_name])[0]
        config_name = model_info['config'].strip()
        print(f'processing: {config_name}', flush=True)
        checkpoint = osp.join(args.checkpoint_root,
                              model_info['checkpoint'].strip())
        # build the model from a config file and a checkpoint file
        inference_model(config_name, checkpoint, visualizer, args)
        return

    # test all model
    logger = MMLogger.get_instance(
        name='MMLogger',
        log_file='benchmark_test_image.log',
        log_level=logging.ERROR)

    for model_key in config:
        for model_info in _as_model_list(config[model_key]):
            config_name = model_info['config'].strip()
            # Same progress line as the single-model path (stripped name).
            print(f'processing: {config_name}', flush=True)
            checkpoint = osp.join(args.checkpoint_root,
                                  model_info['checkpoint'].strip())
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_name, checkpoint, visualizer, args,
                                logger)
            except Exception as e:
                # Best-effort sweep: record the failure and keep going so
                # one broken model does not abort the whole benchmark.
                logger.error(f'{config_name} : {e!r}')


if __name__ == '__main__':
    # Script entry point: parse the command line and run the benchmark.
    main(parse_args())