# Copyright (c) OpenMMLab. All rights reserved.
"""Report the FLOPs and parameter count of a segmentor built from a config.

For backbones whose class name contains ``InternImage``, an analytic estimate
of the DCNv3 operator cost is added on top of the FLOPs counted by mmcv's
``get_model_complexity_info`` (presumably because the hook-based counter does
not see that custom op -- confirm against the mmcv version in use).
"""
import argparse
from functools import partial

import mmcv_custom  # noqa: F401,F403
import mmseg_custom  # noqa: F401,F403
import numpy as np
import torch
from mmcv import Config, DictAction
from mmseg.models import build_segmentor

try:
    from mmcv.cnn import get_model_complexity_info
    from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
except ImportError as err:
    # Chain the original error so the real missing name stays visible.
    raise ImportError('Please upgrade mmcv to >0.6.2') from err


def parse_args():
    """Parse command-line arguments for the FLOPs counter.

    Returns:
        argparse.Namespace: Parsed arguments (``config``, ``shape``,
        ``cfg_options``, ``size_divisor``).
    """
    parser = argparse.ArgumentParser(
        description='Get the FLOPs and params of a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[512, 2048],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image, the minimum size that is divisible '
        'by size_divisor, -1 means do not pad the image.')
    return parser.parse_args()


def dcnv3_flops(n, k, c):
    """Estimate the FLOPs of one DCNv3 operator application.

    Args:
        n (int | float): Number of spatial positions (feature-map H * W).
        k (int): Number of sampling points (here the 3x3 kernel area).
        c (int): Number of channels.

    Returns:
        int | float: ``5 * n * k * c``.
    """
    # NOTE(review): the constant 5 is kept as-is from the original estimate;
    # its exact derivation (interpolation + weighting cost per sample) is
    # assumed, not established here -- confirm before citing the numbers.
    return 5 * n * k * c


def _extra_backbone_flops(backbone, input_shape):
    """Return analytic FLOPs added on top of the counter for ``backbone``.

    Only backbones whose class name contains ``InternImage`` get an extra
    term; every other backbone contributes 0.
    """
    if 'InternImage' not in type(backbone).__name__:
        return 0
    _, height, width = input_shape
    extra = 0
    # Assumes stage ``idx`` runs at stride 4 * 2**idx with
    # ``backbone.channels * 2**idx`` channels and one 3x3 DCNv3 per block
    # (``backbone.depths`` gives blocks per stage, e.g. [4, 4, 18, 4]).
    for idx, depth in enumerate(backbone.depths):
        stride = 4 * (2 ** idx)
        channels = backbone.channels * (2 ** idx)
        positions = (height / stride) * (width / stride)
        extra += depth * dcnv3_flops(n=positions, k=3 * 3, c=channels)
    return extra


def get_flops(model, input_shape):
    """Count FLOPs and params of ``model`` for a single ``input_shape`` input.

    Args:
        model (nn.Module): Segmentor whose ``forward`` has been replaced by
            ``forward_dummy``; must expose a ``backbone`` attribute.
        input_shape (tuple[int]): ``(C, H, W)`` of the dummy input.

    Returns:
        tuple[str, str]: Human-readable FLOPs and parameter count.
    """
    flops, params = get_model_complexity_info(
        model, input_shape, as_strings=False)
    flops = flops + _extra_backbone_flops(model.backbone, input_shape)
    return flops_to_string(flops), params_to_string(params)


def _pad_to_divisor(h, w, divisor):
    """Round ``h`` and ``w`` up to multiples of ``divisor`` when it is > 0."""
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor
    return h, w


def main():
    """Build the segmentor from the CLI config and print its FLOPs/params."""
    args = parse_args()

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    orig_shape = (3, h, w)
    divisor = args.size_divisor
    h, w = _pad_to_divisor(h, w, divisor)
    input_shape = (3, h, w)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    # The complexity counter calls ``model(dummy_input)``; route that through
    # the backbone+head-only path.
    if not hasattr(model, 'forward_dummy'):
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))
    model.forward = model.forward_dummy

    flops, params = get_flops(model, input_shape)
    split_line = '=' * 30

    if divisor > 0 and input_shape != orig_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {orig_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()