# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp
import sys
from functools import partial

import mmcv
# NOTE(review): `mmcv_custom` and `mmseg_custom` are not referenced by name in
# this script; presumably they are imported for import-time side effects
# (e.g. registering custom modules) -- confirm before removing.
import mmcv_custom  # noqa: F401
import mmseg_custom  # noqa: F401
import torch.multiprocessing as mp
from mmdeploy.apis import (create_calib_input_data, extract_model,
                           get_predefined_partition_cfg, torch2onnx,
                           torch2torchscript, visualize_model)
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.apis.utils import to_backend
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
                            get_ir_config, get_partition_config,
                            get_root_logger, load_config, target_wrapper)
from torch.multiprocessing import Process, set_start_method


def parse_args():
    """Parse the command line arguments of the export script.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model model')
    parser.add_argument(
        '--test-img',
        default=None,
        type=str,
        nargs='+',
        help='image used to test model')
    parser.add_argument(
        '--work-dir',
        default=os.getcwd(),
        help='the dir to save logs and models')
    parser.add_argument(
        '--calib-dataset-cfg',
        help=('dataset config path used to calibrate in int8 mode. If not '
              'specified, it will use "val" dataset in model config instead.'),
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    parser.add_argument(
        '--quant-image-dir',
        default=None,
        help='Image directory for quantize model.')
    parser.add_argument(
        '--quant', action='store_true', help='Quantize model to low bit.')
    parser.add_argument(
        '--uri',
        default='192.168.1.1:60000',
        help='Remote ipv4:port or ipv6:port for inference on edge device.')
    args = parser.parse_args()

    return args


def create_process(name, target, args, kwargs, ret_value=None):
    """Run ``target`` in a child process and wait for it to finish.

    The target is wrapped with ``target_wrapper`` together with the current
    root logger level and ``ret_value`` before being handed to ``Process``.

    Args:
        name (str): Human readable name of the step, used in log messages.
        target (Callable): The function to run in the child process.
        args (tuple): Positional arguments forwarded to the process target.
        kwargs (dict): Keyword arguments forwarded to the process target.
        ret_value: Optional shared value passed to ``target_wrapper``. When
            given, a non-zero ``ret_value.value`` after the child has joined
            is treated as failure. Defaults to None, in which case the
            outcome of the child is not checked.

    Raises:
        SystemExit: With exit status 1 when ``ret_value`` reports a failure.
    """
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level

    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is not None:
        if ret_value.value != 0:
            logger.error(f'{name} failed.')
            # Use sys.exit rather than the `exit` helper injected by `site`,
            # which is not guaranteed to exist (e.g. `python -S`).
            sys.exit(1)
        else:
            logger.info(f'{name} success.')


def torch2ir(ir_type: IR):
    """Return the conversion function from torch to the intermediate
    representation.

    Args:
        ir_type (IR): The type of the intermediate representation.

    Returns:
        Callable: ``torch2onnx`` for ``IR.ONNX`` or ``torch2torchscript`` for
            ``IR.TORCHSCRIPT``.

    Raises:
        KeyError: If ``ir_type`` is neither of the two supported types.
    """
    if ir_type == IR.ONNX:
        return torch2onnx
    elif ir_type == IR.TORCHSCRIPT:
        return torch2torchscript
    else:
        raise KeyError(f'Unexpected IR type {ir_type}')


def main():
    """Convert a checkpoint to the configured backend and visualize results.

    Steps, in order: optional SDK info dump, conversion to the intermediate
    representation, optional model partition, optional calibration data
    creation, backend conversion, optional ncnn int8 quantization, and finally
    visualization with both the backend model and the PyTorch model.
    """
    args = parse_args()
    set_start_method('spawn', force=True)
    logger = get_root_logger()
    log_level = logging.getLevelName(args.log_level)
    logger.setLevel(log_level)

    pipeline_funcs = [
        torch2onnx, torch2torchscript, extract_model, create_calib_input_data
    ]
    PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
    PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    quant = args.quant
    quant_image_dir = args.quant_image_dir

    # load deploy_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if it does not exist
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        export2SDK(
            deploy_cfg,
            model_cfg,
            args.work_dir,
            pth=checkpoint_path,
            device=args.device)

    # Shared double (no lock) handed to create_process; a non-zero value
    # after a child has joined is treated as failure.
    ret_value = mp.Value('d', 0, lock=False)

    # convert to IR
    ir_config = get_ir_config(deploy_cfg)
    ir_save_file = ir_config['save_file']
    ir_type = IR.get(ir_config['type'])
    torch2ir(ir_type)(
        args.img,
        args.work_dir,
        ir_save_file,
        deploy_cfg_path,
        model_cfg_path,
        checkpoint_path,
        device=args.device)

    # convert backend
    ir_files = [osp.join(args.work_dir, ir_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)

    if partition_cfgs is not None:

        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = get_predefined_partition_cfg(
                deploy_cfg, partition_cfgs['type'])

        # The partitioned files replace the original IR file in `ir_files`.
        origin_ir_file = ir_files[0]
        ir_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            extract_model(
                origin_ir_file,
                start,
                end,
                dynamic_axes=dynamic_axes,
                save_file=save_path)

            ir_files.append(save_path)

    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)
        create_calib_input_data(
            calib_path,
            deploy_cfg_path,
            model_cfg_path,
            checkpoint_path,
            dataset_cfg=args.calib_dataset_cfg,
            dataset_type='val',
            device=args.device)

    backend_files = ir_files
    # convert backend
    backend = get_backend(deploy_cfg)

    # preprocess deploy_cfg
    if backend == Backend.RKNN:
        # TODO: Add this to task_processor in the future
        import tempfile

        from mmdeploy.utils import (get_common_config, get_normalization,
                                    get_quantization_config,
                                    get_rknn_quantization)
        quantization_cfg = get_quantization_config(deploy_cfg)
        common_params = get_common_config(deploy_cfg)
        if get_rknn_quantization(deploy_cfg) is True:
            transform = get_normalization(model_cfg)
            common_params.update(
                dict(
                    mean_values=[transform['mean']],
                    std_values=[transform['std']]))

        # Write the dataset list into a temp file that outlives this block.
        # `delete=False` keeps the file on disk after it is closed, and
        # writing through the handle we created avoids re-opening a name
        # whose file has already been removed.
        with tempfile.NamedTemporaryFile(
                mode='w', suffix='.txt', delete=False) as f:
            f.writelines([osp.abspath(args.img)])
            dataset_file = f.name
        quantization_cfg.setdefault('dataset', dataset_file)

    if backend == Backend.ASCEND:
        # TODO: Add this to backend manager in the future
        if args.dump_info:
            from mmdeploy.backend.ascend import update_sdk_pipeline
            update_sdk_pipeline(args.work_dir)

    # convert to backend
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

    # ncnn quantization
    if backend == Backend.NCNN and quant:
        from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
        from onnx2ncnn_quant_table import get_table

        # NOTE(review): presumably `to_backend` returns alternating
        # param/bin paths for ncnn -- confirm against mmdeploy.
        model_param_paths = backend_files[::2]
        model_bin_paths = backend_files[1::2]
        backend_files = []
        for onnx_path, model_param_path, model_bin_path in zip(
                ir_files, model_param_paths, model_bin_paths):

            deploy_cfg, model_cfg = load_config(deploy_cfg_path,
                                                model_cfg_path)
            quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file(  # noqa: E501
                onnx_path, args.work_dir)

            create_process(
                'ncnn quant table',
                target=get_table,
                args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
                      quant_table, quant_image_dir, args.device),
                kwargs=dict(),
                ret_value=ret_value)

            create_process(
                'ncnn_int8',
                target=ncnn2int8,
                args=(model_param_path, model_bin_path, quant_table,
                      quant_param, quant_bin),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [quant_param, quant_bin]

    # Fall back to the conversion image when no test image is given.
    if args.test_img is None:
        args.test_img = args.img

    extra = dict(
        backend=backend,
        output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
        show_result=args.show)
    if backend == Backend.SNPE:
        extra['uri'] = args.uri

    # get backend inference result, try render
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=extra,
        ret_value=ret_value)

    # get pytorch model inference result, try visualize if possible
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)
    logger.info('All process success.')


if __name__ == '__main__':
    main()