Spaces:
Running
Running
| import os | |
| import onnxruntime | |
class ONNXEngine:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    The session is created once in ``__init__`` and the names of every model
    input and output are cached, so ``run`` only needs the input array.
    """

    def __init__(self, onnx_path, use_gpu):
        """Create the inference session for the model at ``onnx_path``.

        :param onnx_path: path to the ``.onnx`` model file.
        :param use_gpu: when truthy, request the TensorRT and CUDA execution
            providers ahead of the CPU provider; otherwise CPU only.
        :raises FileNotFoundError: if ``onnx_path`` does not exist.
        """
        if not os.path.exists(onnx_path):
            # FileNotFoundError is a subclass of Exception, so callers that
            # caught the previous generic Exception keep working.
            raise FileNotFoundError(f'{onnx_path} is not exists')
        providers = self._select_providers(use_gpu)
        self.onnx_session = onnxruntime.InferenceSession(onnx_path,
                                                         providers=providers)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)

    @staticmethod
    def _select_providers(use_gpu):
        """Return the execution-provider names to request, as a flat list.

        The list must contain plain provider-name strings. The previous code
        wrapped the GPU list in a one-element tuple (``([...], )``), which
        onnxruntime rejects because the single entry is a list rather than a
        provider name.

        :param use_gpu: whether GPU providers should be requested.
        :return: list of provider name strings, most preferred first.
        """
        if use_gpu:
            return [
                'TensorrtExecutionProvider',
                'CUDAExecutionProvider',
                'CPUExecutionProvider',
            ]
        return ['CPUExecutionProvider']

    def get_output_name(self, onnx_session):
        """Collect the names of all outputs of the session.

        :param onnx_session: object exposing ``get_outputs()``, whose items
            have a ``name`` attribute.
        :return: list of output names, in the order reported by the session.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Collect the names of all inputs of the session.

        :param onnx_session: object exposing ``get_inputs()``, whose items
            have a ``name`` attribute.
        :return: list of input names, in the order reported by the session.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the feed dict that maps every input name to the same array.

        :param input_name: iterable of model input names.
        :param image_numpy: the value fed to each of those inputs.
        :return: dict of ``{name: image_numpy}`` for every name given.
        """
        # Every input receives the very same object; for a model with several
        # inputs this means they are all fed identical data.
        return {name: image_numpy for name in input_name}

    def run(self, image_numpy):
        """Run the model on ``image_numpy`` and return all cached outputs.

        :param image_numpy: input data; its dtype must match what the model
            expects.
        :return: whatever ``onnx_session.run`` returns for the cached output
            names.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        return self.onnx_session.run(self.output_name, input_feed=input_feed)