Spaces:
Runtime error
Runtime error
File size: 5,283 Bytes
ac7cda5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import os
import torch
import argparse
def onnx_to_trt(onnx_file, trt_file, fp16=False, more_cmd=None):
    """Convert an ONNX model to a TensorRT engine via the ``polygraphy`` CLI.

    The command is executed as an argument vector (no shell), so paths that
    contain spaces or shell metacharacters are passed through unchanged.

    Args:
        onnx_file: Path of the input ONNX model.
        trt_file: Path handed to ``-o`` for the generated engine.
        fp16: When true, ``--fp16`` is added to the command.
        more_cmd: Optional list of extra command-line fragments. A fragment
            may hold several whitespace-separated tokens (for example
            ``"--trt-min-shapes audio:[1,1000]"``); each fragment is split
            into individual arguments.

    Returns:
        None. A launch failure or a non-zero exit status is reported on
        stdout; no exception is raised for it.
    """
    # Function-scope imports keep this block self-contained.
    import shlex
    import subprocess

    args = ["polygraphy", "convert", onnx_file, "-o", trt_file]

    # The hardware-compatibility flag is only requested when the major
    # compute capability is 8 or higher.
    cap = torch.cuda.get_device_capability()
    if cap[0] >= 8:
        args.append("--hardware-compatibility-level=Ampere_Plus")
    if fp16:
        args.append("--fp16")
    args.append("--builder-optimization-level=5")

    # Fragments used to be glued together and word-split by the shell;
    # split them here so the argument vector matches what the shell produced.
    for fragment in more_cmd or ():
        args.extend(shlex.split(fragment, posix=os.name != "nt"))

    print(shlex.join(args))
    try:
        completed = subprocess.run(args, check=False)
    except OSError as exc:
        # Typically: the polygraphy executable is not on PATH.
        print(f"Failed to launch polygraphy: {exc}")
        return
    if completed.returncode != 0:
        print(f"polygraphy exited with code {completed.returncode} for {onnx_file}")
def onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16=False, plugin_file="./libgrid_sample_3d_plugin.so"):
    """Build a TensorRT engine from an ONNX model with a plugin library loaded.

    The plugin library is loaded into the builder's plugin registry before
    the ONNX file is parsed, and is also listed in the config's
    ``plugins_to_serialize``. Layers whose name contains ``"GridSample"``
    (and whose type is not in the exclusion set below) get their precision
    set to float32.

    Args:
        onnx_file: Path of the input ONNX model.
        trt_file: Path the serialized engine is written to.
        fp16: When true, the FP16 builder flag is enabled.
        plugin_file: Path of the plugin shared library to load.

    Returns:
        None. If parsing fails, the parser errors are printed and nothing is
        built or written. If the build yields no engine, a message is
        printed and no file is written.
    """
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    trt.init_libnvinfer_plugins(logger, "")

    plugin_libs = [plugin_file]

    builder = trt.Builder(logger)
    for plugin_lib in plugin_libs:
        builder.get_plugin_registry().load_library(plugin_lib)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_file):
        print(f"Fail parsing {onnx_file}")
        for i in range(parser.num_errors):
            error = parser.get_error(i)
            print(error)
            print(
                f"{error.code() = }\n{error.file() = }\n{error.func() = }\n{error.line() = }\n{error.local_function_stack_size() = }"
            )
            print(
                f"{error.local_function_stack() = }\n{error.node_name() = }\n{error.node_operator() = }\n{error.node() = }"
            )
        parser.clear_errors()
        # Do not build from a network that failed to parse: a written engine
        # file would be treated as a finished conversion on later runs.
        return

    config = builder.create_builder_config()
    config.builder_optimization_level = 5

    # Hardware-compatible engines are only supported on Ampere and beyond,
    # so the level is only set when the major compute capability is >= 8.
    cap = torch.cuda.get_device_capability()
    if cap[0] >= 8:
        config.hardware_compatibility_level = (
            trt.HardwareCompatibilityLevel.AMPERE_PLUS
        )

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        # NOTE(review): nesting of this flag under `fp16` was reconstructed
        # from a source whose indentation was lost - confirm.
        config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)

    config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806, True)

    # Layer types that are never given an explicit precision.
    excluded_types = {
        "SHAPE",
        "ASSERTION",
        "SHUFFLE",
        "IDENTITY",
        "CONSTANT",
        "CONCATENATION",
        "GATHER",
        "SLICE",
        "CONDITION",
        "CONDITIONAL_INPUT",
        "CONDITIONAL_OUTPUT",
        "FILL",
        "NON_ZERO",
        "ONE_HOT",
    }
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        # Drops the first 10 characters of the type's string form
        # (presumably the "LayerType." prefix - confirm).
        if str(layer.type)[10:] in excluded_types:
            continue
        if "GridSample" in layer.name:
            print(f"set {layer.name} to float32")
            layer.precision = trt.float32

    config.plugins_to_serialize = plugin_libs
    engine_bytes = builder.build_serialized_network(network, config)
    if engine_bytes is None:
        print(f"Failed to build engine for {onnx_file}")
        return
    with open(trt_file, "wb") as f:
        f.write(engine_bytes)
def main(onnx_dir, trt_dir, grid_sample_plugin_file=""):
    """Convert every ``*.onnx`` file in ``onnx_dir`` into an engine in ``trt_dir``.

    Engines are named ``<name>_fp16.engine`` or ``<name>_fp32.engine``.
    Models named ``motion_extractor``, ``hubert`` or ``wavlm``, and any model
    whose name starts with ``lmdm``, are built without fp16. The model
    ``warp_network_ori`` is ignored, and a model whose engine file already
    exists is skipped.

    Args:
        onnx_dir: Directory scanned for ``.onnx`` files.
        trt_dir: Directory the engine files are written to.
        grid_sample_plugin_file: Plugin library path forwarded to
            ``onnx_to_trt_for_gridsample`` for the ``warp_network`` model.
    """
    suffix = ".onnx"
    banner = "=" * 20
    fp32_models = {"motion_extractor", "hubert", "wavlm"}
    # Extra polygraphy flags for the models that need explicit shape ranges.
    shape_flags = {
        "wavlm": [
            "--trt-min-shapes audio:[1,1000]",
            "--trt-max-shapes audio:[1,320080]",
            "--trt-opt-shapes audio:[1,320080]",
        ],
        "hubert": [
            "--trt-min-shapes input_values:[1,3240]",
            "--trt-max-shapes input_values:[1,12960]",
            "--trt-opt-shapes input_values:[1,6480]",
        ],
    }

    model_names = [
        entry[: -len(suffix)]
        for entry in os.listdir(onnx_dir)
        if entry.endswith(suffix)
    ]
    for name in model_names:
        if name == "warp_network_ori":
            continue

        print(banner, f"{name} start", banner)

        use_fp16 = not (name in fp32_models or name.startswith("lmdm"))
        bits = 16 if use_fp16 else 32
        source_path = f"{onnx_dir}/{name}.onnx"
        engine_path = f"{trt_dir}/{name}_fp{bits}.engine"

        # An existing engine file counts as already converted.
        if os.path.isfile(engine_path):
            print(banner, f"{name} skip", banner)
            continue

        if name != "warp_network":
            onnx_to_trt(source_path, engine_path, use_fp16, more_cmd=shape_flags.get(name))
        else:
            onnx_to_trt_for_gridsample(
                source_path, engine_path, use_fp16, plugin_file=grid_sample_plugin_file
            )

        print(banner, f"{name} done", banner)
if __name__ == "__main__":
    # Command-line entry point: convert the ONNX models found in --onnx_dir
    # into TensorRT engines stored in --trt_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--onnx_dir", type=str, required=True, help="input onnx dir")
    parser.add_argument("--trt_dir", type=str, required=True, help="output trt dir")
    args = parser.parse_args()

    onnx_dir = args.onnx_dir
    trt_dir = args.trt_dir

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isdir(onnx_dir):
        parser.error(f"--onnx_dir is not an existing directory: {onnx_dir}")
    os.makedirs(trt_dir, exist_ok=True)

    # The plugin library is looked up inside the ONNX directory.
    grid_sample_plugin_file = os.path.join(onnx_dir, "libgrid_sample_3d_plugin.so")
    main(onnx_dir, trt_dir, grid_sample_plugin_file)
|