# Spaces:
# Runtime error
# Runtime error
import argparse
import os
import shlex
import subprocess

import torch
def onnx_to_trt(onnx_file, trt_file, fp16=False, more_cmd=None): | |
cap = torch.cuda.get_device_capability() | |
if cap[0] >= 8: | |
compatiable = "--hardware-compatibility-level=Ampere_Plus" | |
else: | |
compatiable = "" | |
cmd = [ | |
"polygraphy", | |
"convert", | |
onnx_file, | |
"-o", | |
trt_file, | |
compatiable, | |
"--fp16" if fp16 else "", | |
f"--builder-optimization-level=5", | |
] | |
if more_cmd: | |
cmd = cmd + more_cmd | |
print(" ".join(cmd)) | |
os.system(" ".join(cmd)) | |
def onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16=False, plugin_file="./libgrid_sample_3d_plugin.so"): | |
import tensorrt as trt | |
logger = trt.Logger(trt.Logger.INFO) | |
trt.init_libnvinfer_plugins(logger, "") | |
plugin_libs = [plugin_file] | |
onnx_path = onnx_file | |
engine_path = trt_file | |
builder = trt.Builder(logger) | |
for pluginlib in plugin_libs: | |
builder.get_plugin_registry().load_library(pluginlib) | |
network = builder.create_network( | |
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) | |
) | |
parser = trt.OnnxParser(network, logger) | |
res = parser.parse_from_file(onnx_path) # parse from file | |
if not res: | |
print(f"Fail parsing {onnx_path}") | |
for i in range(parser.num_errors): # Get error information | |
error = parser.get_error(i) | |
print(error) # Print error information | |
print( | |
f"{error.code() = }\n{error.file() = }\n{error.func() = }\n{error.line() = }\n{error.local_function_stack_size() = }" | |
) | |
print( | |
f"{error.local_function_stack() = }\n{error.node_name() = }\n{error.node_operator() = }\n{error.node() = }" | |
) | |
parser.clear_errors() | |
config = builder.create_builder_config() | |
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) | |
config.builder_optimization_level = 5 | |
# Set the flag of hardware compatibility, Hardware-compatible engines are only supported on Ampere and beyond | |
cap = torch.cuda.get_device_capability() | |
if cap[0] >= 8: | |
compatible = True | |
else: | |
compatible = False | |
if compatible: | |
config.hardware_compatibility_level = ( | |
trt.HardwareCompatibilityLevel.AMPERE_PLUS | |
) | |
if fp16: | |
config.set_flag(trt.BuilderFlag.FP16) | |
config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS) | |
config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806, True) | |
exclude_list = [ | |
"SHAPE", | |
"ASSERTION", | |
"SHUFFLE", | |
"IDENTITY", | |
"CONSTANT", | |
"CONCATENATION", | |
"GATHER", | |
"SLICE", | |
"CONDITION", | |
"CONDITIONAL_INPUT", | |
"CONDITIONAL_OUTPUT", | |
"FILL", | |
"NON_ZERO", | |
"ONE_HOT", | |
] | |
for i in range(0, network.num_layers): | |
layer = network.get_layer(i) | |
if str(layer.type)[10:] in exclude_list: | |
continue | |
if "GridSample" in layer.name: | |
print(f"set {layer.name} to float32") | |
layer.precision = trt.float32 | |
config.plugins_to_serialize = plugin_libs | |
engineString = builder.build_serialized_network(network, config) | |
if engineString is not None: | |
with open(engine_path, "wb") as f: | |
f.write(engineString) | |
def main(onnx_dir, trt_dir, grid_sample_plugin_file=""): | |
names = [i[:-5] for i in os.listdir(onnx_dir) if i.endswith(".onnx")] | |
for name in names: | |
if name == "warp_network_ori": | |
continue | |
print("=" * 20, f"{name} start", "=" * 20) | |
fp16 = False if name in {"motion_extractor", "hubert", "wavlm"} or name.startswith("lmdm") else True | |
more_cmd = None | |
if name == "wavlm": | |
more_cmd = [ | |
"--trt-min-shapes audio:[1,1000]", | |
"--trt-max-shapes audio:[1,320080]", | |
"--trt-opt-shapes audio:[1,320080]", | |
] | |
elif name == "hubert": | |
more_cmd = [ | |
"--trt-min-shapes input_values:[1,3240]", | |
"--trt-max-shapes input_values:[1,12960]", | |
"--trt-opt-shapes input_values:[1,6480]", | |
] | |
onnx_file = f"{onnx_dir}/{name}.onnx" | |
trt_file = f"{trt_dir}/{name}_fp{16 if fp16 else 32}.engine" | |
if os.path.isfile(trt_file): | |
print("=" * 20, f"{name} skip", "=" * 20) | |
continue | |
if name == "warp_network": | |
onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16, plugin_file=grid_sample_plugin_file) | |
else: | |
onnx_to_trt(onnx_file, trt_file, fp16, more_cmd=more_cmd) | |
print("=" * 20, f"{name} done", "=" * 20) | |
if __name__ == "__main__":
    # Command-line entry point: validate directories, then convert all models.
    parser = argparse.ArgumentParser(description="Convert ONNX models to TensorRT engines.")
    parser.add_argument("--onnx_dir", type=str, required=True, help="input onnx dir")
    parser.add_argument("--trt_dir", type=str, required=True, help="output trt dir")
    parser.add_argument(
        "--grid_sample_plugin_file",
        type=str,
        default=None,
        help="GridSample plugin library (default: <onnx_dir>/libgrid_sample_3d_plugin.so)",
    )
    args = parser.parse_args()
    onnx_dir = args.onnx_dir
    trt_dir = args.trt_dir
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isdir(onnx_dir):
        parser.error(f"--onnx_dir is not a directory: {onnx_dir}")
    os.makedirs(trt_dir, exist_ok=True)
    grid_sample_plugin_file = args.grid_sample_plugin_file or os.path.join(
        onnx_dir, "libgrid_sample_3d_plugin.so"
    )
    main(onnx_dir, trt_dir, grid_sample_plugin_file)