"""Convert the ONNX models in a directory to TensorRT engines."""

import argparse
import os
import shlex
import subprocess

import torch

# Models always built in FP32; names starting with "lmdm" are also kept in FP32.
_FP32_MODELS = frozenset({"motion_extractor", "hubert", "wavlm"})

# Extra polygraphy options (input-shape ranges) for specific models.
# Each entry is an "option value" string, tokenized by onnx_to_trt().
_EXTRA_POLYGRAPHY_ARGS = {
    "wavlm": (
        "--trt-min-shapes audio:[1,1000]",
        "--trt-max-shapes audio:[1,320080]",
        "--trt-opt-shapes audio:[1,320080]",
    ),
    "hubert": (
        "--trt-min-shapes input_values:[1,3240]",
        "--trt-max-shapes input_values:[1,12960]",
        "--trt-opt-shapes input_values:[1,6480]",
    ),
}

# Layer types skipped when pinning GridSample layers to FP32.
_PRECISION_EXCLUDED_LAYER_TYPES = frozenset({
    "SHAPE",
    "ASSERTION",
    "SHUFFLE",
    "IDENTITY",
    "CONSTANT",
    "CONCATENATION",
    "GATHER",
    "SLICE",
    "CONDITION",
    "CONDITIONAL_INPUT",
    "CONDITIONAL_OUTPUT",
    "FILL",
    "NON_ZERO",
    "ONE_HOT",
})


def onnx_to_trt(onnx_file, trt_file, fp16=False, more_cmd=None):
    """Convert an ONNX model to a TensorRT engine with the ``polygraphy`` CLI.

    Args:
        onnx_file: Path of the input ONNX model.
        trt_file: Path the serialized engine is written to.
        fp16: If True, pass ``--fp16`` to polygraphy.
        more_cmd: Optional iterable of extra command-line options. An entry may
            hold an option and its value separated by whitespace (e.g.
            ``"--trt-min-shapes audio:[1,1000]"``); entries are tokenized with
            shell quoting rules.

    Returns:
        True if polygraphy could be started and exited with status 0,
        False otherwise.
    """
    cmd = ["polygraphy", "convert", onnx_file, "-o", trt_file]
    # Hardware-compatible engines are only supported on Ampere and newer.
    if torch.cuda.get_device_capability()[0] >= 8:
        cmd.append("--hardware-compatibility-level=Ampere_Plus")
    if fp16:
        cmd.append("--fp16")
    cmd.append("--builder-optimization-level=5")
    for extra in more_cmd or ():
        cmd.extend(shlex.split(extra))

    print(shlex.join(cmd))
    # Run without a shell so paths with spaces/metacharacters and shape specs
    # such as "audio:[1,1000]" reach polygraphy verbatim (no globbing).
    try:
        returncode = subprocess.run(cmd, check=False).returncode
    except OSError as err:  # e.g. polygraphy is not on PATH
        print(f"Failed to run polygraphy: {err}")
        return False
    if returncode != 0:
        print(f"polygraphy exited with status {returncode} for {onnx_file}")
        return False
    return True


def _print_onnx_parse_errors(parser):
    """Print every error recorded by a TensorRT ``OnnxParser``, then clear them."""
    for i in range(parser.num_errors):
        error = parser.get_error(i)
        print(error)
        print(
            f"{error.code() = }\n{error.file() = }\n{error.func() = }\n{error.line() = }\n{error.local_function_stack_size() = }"
        )
        print(
            f"{error.local_function_stack() = }\n{error.node_name() = }\n{error.node_operator() = }\n{error.node() = }"
        )
    parser.clear_errors()


def onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16=False, plugin_file="./libgrid_sample_3d_plugin.so"):
    """Build a TensorRT engine for a network that needs the GridSample plugin.

    Uses the TensorRT Python API directly so the plugin library can be loaded
    into the builder and registered for serialization, and so layers whose
    name contains ``GridSample`` can be pinned to FP32.

    Args:
        onnx_file: Path of the input ONNX model.
        trt_file: Path the serialized engine is written to.
        fp16: If True, enable FP16 and prefer per-layer precision constraints.
        plugin_file: Path of the plugin shared library to load.

    Returns:
        True if the engine was built and written. False if ONNX parsing or
        engine building failed; nothing is written in that case.
    """
    import tensorrt as trt  # imported lazily: only this path needs the TensorRT Python API

    logger = trt.Logger(trt.Logger.INFO)
    trt.init_libnvinfer_plugins(logger, "")
    plugin_libs = [plugin_file]

    builder = trt.Builder(logger)
    registry = builder.get_plugin_registry()
    for plugin_lib in plugin_libs:
        registry.load_library(plugin_lib)

    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_file):
        print(f"Fail parsing {onnx_file}")
        _print_onnx_parse_errors(parser)
        # Never build a partially parsed network: the resulting engine would be
        # written to disk and then skipped by main() on every later run.
        return False

    config = builder.create_builder_config()
    config.builder_optimization_level = 5
    # Hardware-compatible engines are only supported on Ampere and newer.
    if torch.cuda.get_device_capability()[0] >= 8:
        config.hardware_compatibility_level = trt.HardwareCompatibilityLevel.AMPERE_PLUS
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        # Lets the per-layer FP32 requests below take effect under FP16.
        config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
        config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806, True)

    for i in range(network.num_layers):
        layer = network.get_layer(i)
        # str() of a LayerType is "LayerType.<NAME>"; compare on <NAME> only.
        if str(layer.type).split(".")[-1] in _PRECISION_EXCLUDED_LAYER_TYPES:
            continue
        if "GridSample" in layer.name:
            print(f"set {layer.name} to float32")
            layer.precision = trt.float32

    config.plugins_to_serialize = plugin_libs
    engine_bytes = builder.build_serialized_network(network, config)
    if engine_bytes is None:
        print(f"Fail building engine for {onnx_file}")
        return False
    with open(trt_file, "wb") as f:
        f.write(engine_bytes)
    return True


def main(onnx_dir, trt_dir, grid_sample_plugin_file=""):
    """Convert every ``*.onnx`` model in *onnx_dir* into an engine in *trt_dir*.

    Engines are named ``<model>_fp16.engine`` or ``<model>_fp32.engine``.
    Models listed in ``_FP32_MODELS``, and those starting with ``lmdm``, are
    built in FP32. Models whose engine file already exists are skipped, and
    ``warp_network_ori`` is never converted. ``warp_network`` is built with
    the TensorRT API using *grid_sample_plugin_file*; all other models are
    converted with polygraphy.
    """
    suffix = ".onnx"
    names = sorted(f[: -len(suffix)] for f in os.listdir(onnx_dir) if f.endswith(suffix))
    for name in names:
        if name == "warp_network_ori":
            continue
        print("=" * 20, f"{name} start", "=" * 20)

        fp16 = not (name in _FP32_MODELS or name.startswith("lmdm"))
        onnx_file = os.path.join(onnx_dir, f"{name}{suffix}")
        trt_file = os.path.join(trt_dir, f"{name}_fp{16 if fp16 else 32}.engine")
        if os.path.isfile(trt_file):
            print("=" * 20, f"{name} skip", "=" * 20)
            continue

        if name == "warp_network":
            ok = onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16, plugin_file=grid_sample_plugin_file)
        else:
            ok = onnx_to_trt(onnx_file, trt_file, fp16, more_cmd=_EXTRA_POLYGRAPHY_ARGS.get(name))
        print("=" * 20, f"{name} {'done' if ok else 'failed'}", "=" * 20)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert ONNX models to TensorRT engines.")
    parser.add_argument("--onnx_dir", type=str, required=True, help="input onnx dir")
    parser.add_argument("--trt_dir", type=str, required=True, help="output trt dir")
    args = parser.parse_args()

    onnx_dir = args.onnx_dir
    trt_dir = args.trt_dir
    # Validate explicitly: an assert would be stripped under `python -O`.
    if not os.path.isdir(onnx_dir):
        parser.error(f"--onnx_dir is not a directory: {onnx_dir}")
    os.makedirs(trt_dir, exist_ok=True)

    # The GridSample plugin library is looked up next to the ONNX models.
    grid_sample_plugin_file = os.path.join(onnx_dir, "libgrid_sample_3d_plugin.so")
    main(onnx_dir, trt_dir, grid_sample_plugin_file)