import os
from typing import Any, Callable, Tuple, Union

import torch
import torch.nn as nn

HOME = os.environ["HOME"].rstrip("/")
print('HOME', HOME)


def import_abspy(name="models", path="classification/"):
    """Import module `name` from directory `path`, which is temporarily prepended to sys.path."""
    import sys
    import importlib

    path = os.path.abspath(path)
    print(path)
    assert os.path.isdir(path)
    sys.path.insert(0, path)
    module = importlib.import_module(name)
    sys.path.pop(0)
    return module
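
# Usage sketch (added for illustration; the path and module layout below are hypothetical):
#
#   models = import_abspy(name="models", path=os.path.join(HOME, "VMamba/classification"))
#   print(models.__file__)
#
# The directory is only on sys.path for the duration of the import.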


def print_jit_input_names(inputs):
    # Print the debugName of each traced input; stop quietly once indexing runs out.
    print("input params: ", end=" ", flush=True)
    try:
        for i in range(10):
            print(inputs[i].debugName(), end=" ", flush=True)
    except Exception:
        pass
    print("", flush=True)


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
    [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    def get_flops_einsum(input_shapes, equation):
        # Let numpy's einsum optimizer report the FLOP count of this contraction.
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # halve the reported count so that a multiply-add pair counts as one FLOP
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0

    # delta (B D L) combined with A (D N)
    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    # delta (B D L) combined with B (B N L) and u (B D L)
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")

    # sequential scan over L: per-step state update plus the contraction with C
    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops


def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    assert not with_complex
    # closed-form count for the selective-scan core: 9 ops per (b, l, d, n) element
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops
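
# Worked example (a rough sanity check added for illustration, not part of the original file):
# with the default sizes B=1, L=256, D=768, N=16 and with_D=True, with_Z=False the
# closed-form count is
#   9 * 1 * 256 * 768 * 16 + 1 * 768 * 256 = 28_311_552 + 196_608 = 28_508_160 FLOPs,
# which should roughly agree with the einsum-path based estimate:
#
#   flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False)
#   flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False)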


def selective_scan_flop_jit(inputs, outputs, backend="prefixsum", verbose=True):
    if verbose:
        print_jit_input_names(inputs)
    flops_fn = flops_selective_scan_ref if backend == "naive" else flops_selective_scan_fn
    # inputs[0] is u of shape (B, D, L) and inputs[2] is A of shape (D, N),
    # matching the shapes documented in flops_selective_scan_ref
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
    return flops


class FLOPs:
    @staticmethod
    def register_supported_ops():
        # Custom jit op handles; a None handle tells fvcore/mmengine to ignore the op
        # (count zero FLOPs) rather than report it as unsupported.
        supported_ops = {
            "aten::gelu": None,
            "aten::silu": None,
            "aten::neg": None,
            "aten::exp": None,
            "aten::flip": None,
            "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScan": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
        }
        return supported_ops

    @staticmethod
    def check_operations(model: nn.Module, inputs=None, input_shape=(3, 224, 224)):
        from fvcore.nn.jit_analysis import _get_scoped_trace_graph, _named_modules_with_dup, Counter, JitModelAnalysis

        if inputs is None:
            assert input_shape is not None
            if len(input_shape) == 1:
                input_shape = (1, 3, input_shape[0], input_shape[0])
            elif len(input_shape) == 2:
                input_shape = (1, 3, *input_shape)
            elif len(input_shape) == 3:
                input_shape = (1, *input_shape)
            else:
                assert len(input_shape) == 4
            inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)

        model.eval()

        # With all handles cleared, every traced op shows up as unsupported; this dumps
        # the full op list so new handles can be added to register_supported_ops.
        flop_counter = JitModelAnalysis(model, inputs)
        flop_counter._ignored_ops = set()
        flop_counter._op_handles = dict()
        assert flop_counter.total() == 0
        print(flop_counter.unsupported_ops(), flush=True)
        print(f"supported ops {flop_counter._op_handles}; ignore ops {flop_counter._ignored_ops};", flush=True)

    @classmethod
    def fvcore_flop_count(cls, model: nn.Module, inputs=None, input_shape=(3, 224, 224), show_table=False,
                          show_arch=False, verbose=True):
        supported_ops = cls.register_supported_ops()
        from fvcore.nn.parameter_count import parameter_count as fvcore_parameter_count
        from fvcore.nn.flop_count import flop_count, FlopCountAnalysis, _DEFAULT_SUPPORTED_OPS
        from fvcore.nn.print_model_statistics import flop_count_str, flop_count_table
        from fvcore.nn.jit_analysis import _IGNORED_OPS
        from fvcore.nn.jit_handles import get_shape, addmm_flop_jit

        if inputs is None:
            assert input_shape is not None
            if len(input_shape) == 1:
                input_shape = (1, 3, input_shape[0], input_shape[0])
            elif len(input_shape) == 2:
                input_shape = (1, 3, *input_shape)
            elif len(input_shape) == 3:
                input_shape = (1, *input_shape)
            else:
                assert len(input_shape) == 4
            inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)

        model.eval()

        print("model Prepared")

        Gflops, unsupported = flop_count(model=model, inputs=inputs, supported_ops=supported_ops)

        print("flop_count Done")

        flops_table = flop_count_table(
            flops=FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
            max_depth=100,
            activations=None,
            show_param_shapes=True,
        )

        flops_str = flop_count_str(
            flops=FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
            activations=None,
        )

        if show_arch:
            print(flops_str)

        if show_table:
            print(flops_table)

        params = fvcore_parameter_count(model)[""]
        flops = sum(Gflops.values())

        if verbose:
            print(Gflops.items())
            print("GFlops: ", flops, "Params: ", params, flush=True)

        return params, flops
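
    # Usage sketch (added for illustration; `net` is a hypothetical model): for a 224x224
    # RGB input this returns (params, gflops). fvcore's flop_count reports per-operator
    # counts already in GFLOPs, so the summed value is in GFLOPs as well.
    #
    #   params, gflops = FLOPs.fvcore_flop_count(net, input_shape=(3, 224, 224), show_table=True)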

    @classmethod
    def mmengine_flop_count(cls, model: nn.Module = None, input_shape=(3, 224, 224), show_table=False, show_arch=False,
                            _get_model_complexity_info=False):
        supported_ops = cls.register_supported_ops()
        from mmengine.analysis.print_helper import is_tuple_of, FlopAnalyzer, ActivationAnalyzer, parameter_count, \
            _format_size, complexity_stats_table, complexity_stats_str
        from mmengine.analysis.jit_analysis import _IGNORED_OPS
        from mmengine.analysis.complexity_analysis import _DEFAULT_SUPPORTED_FLOP_OPS, _DEFAULT_SUPPORTED_ACT_OPS
        from mmengine.analysis import get_model_complexity_info as mm_get_model_complexity_info

        # Local variant of mmengine's get_model_complexity_info that plugs in the
        # selective-scan op handles registered above.
        def get_model_complexity_info(
            model: nn.Module,
            input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None] = None,
            inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...], None] = None,
            show_table: bool = True,
            show_arch: bool = True,
        ):
            if input_shape is None and inputs is None:
                raise ValueError('One of "input_shape" and "inputs" should be set.')
            elif input_shape is not None and inputs is not None:
                raise ValueError('"input_shape" and "inputs" cannot be both set.')

            if inputs is None:
                device = next(model.parameters()).device
                if is_tuple_of(input_shape, int):
                    inputs = (torch.randn(1, *input_shape).to(device),)
                elif is_tuple_of(input_shape, tuple) and all([
                        is_tuple_of(one_input_shape, int)
                        for one_input_shape in input_shape
                ]):
                    inputs = tuple([
                        torch.randn(1, *one_input_shape).to(device)
                        for one_input_shape in input_shape
                    ])
                else:
                    raise ValueError(
                        '"input_shape" should be either a `tuple of int` (to construct '
                        'one input tensor) or a `tuple of tuple of int` (to construct '
                        'multiple input tensors).')

            flop_handler = FlopAnalyzer(model, inputs).set_op_handle(**supported_ops)

            flops = flop_handler.total()
            params = parameter_count(model)['']
            flops_str = _format_size(flops)
            params_str = _format_size(params)

            if show_table:
                complexity_table = complexity_stats_table(
                    flops=flop_handler,
                    show_param_shapes=True,
                )
                complexity_table = '\n' + complexity_table
            else:
                complexity_table = ''

            if show_arch:
                complexity_arch = complexity_stats_str(
                    flops=flop_handler,
                )
                complexity_arch = '\n' + complexity_arch
            else:
                complexity_arch = ''

            return {
                'flops': flops,
                'flops_str': flops_str,
                'params': params,
                'params_str': params_str,
                'out_table': complexity_table,
                'out_arch': complexity_arch
            }

        if _get_model_complexity_info:
            return get_model_complexity_info

        model.eval()
        analysis_results = get_model_complexity_info(
            model,
            input_shape,
            show_table=show_table,
            show_arch=show_arch,
        )
        flops = analysis_results['flops_str']
        params = analysis_results['params_str']
        out_table = analysis_results['out_table']
        out_arch = analysis_results['out_arch']

        if show_arch:
            print(out_arch)

        if show_table:
            print(out_table)

        split_line = '=' * 30
        print(f'{split_line}\nInput shape: {input_shape}\t'
              f'Flops: {flops}\tParams: {params}', flush=True)
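

# A minimal smoke test (a sketch added for illustration, not part of the original module).
# It assumes fvcore and mmengine are installed and runs both counters on a tiny CNN; the
# selective-scan handles are registered but simply go unused for this plain conv model.
if __name__ == "__main__":
    _toy = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    FLOPs.fvcore_flop_count(_toy, input_shape=(3, 224, 224), show_table=False)
    FLOPs.mmengine_flop_count(_toy, input_shape=(3, 224, 224), show_table=False)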