import sys

import numpy as np
import onnx
import onnxruntime as ort
import torch

from utils.model import AttentionUNet


def convert_to_onnx(pytorch_model_path, onnx_output_path, input_size=256):
    """
    Convert a PyTorch model to ONNX format.

    Args:
        pytorch_model_path: Path to the PyTorch model checkpoint (a state dict)
        onnx_output_path: Path to save the ONNX model
        input_size: Input size for the model (default is 256x256)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device used for conversion: {device}")

    model = AttentionUNet(in_channels=3, out_channels=1)
    model.to(device)
    model.load_state_dict(torch.load(pytorch_model_path, map_location=device))
    model.eval()

    # Create a dummy input used to trace the model during export
    dummy_input = torch.randn(1, 3, input_size, input_size, device=device)

    # Export the model
    torch.onnx.export(
        model,  # model being run
        dummy_input,  # model input (or a tuple for multiple inputs)
        onnx_output_path,  # where to save the model
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=12,  # the ONNX opset version to target
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},  # variable-size axes
            "output": {0: "batch_size", 2: "height", 3: "width"},
        },
    )
    print(f"Model converted and saved to {onnx_output_path}")
    verify_onnx_model(onnx_output_path, input_size)
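

# Optional sketch, not part of the original conversion flow: a numerical
# parity check that feeds the same random input through the PyTorch model and
# the exported ONNX model and reports the largest elementwise difference.
# It assumes AttentionUNet takes a single 4-D float tensor and returns a
# single tensor, as the export call above implies; small differences
# (on the order of 1e-5) are expected from floating-point reordering.
def compare_pytorch_onnx_outputs(pytorch_model_path, onnx_model_path, input_size=256):
    # Run both models on CPU so the comparison is deterministic
    model = AttentionUNet(in_channels=3, out_channels=1)
    model.load_state_dict(torch.load(pytorch_model_path, map_location="cpu"))
    model.eval()

    dummy_input = torch.randn(1, 3, input_size, input_size)
    with torch.no_grad():
        torch_output = model(dummy_input).numpy()

    session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    onnx_output = session.run(None, {input_name: dummy_input.numpy()})[0]

    max_diff = float(np.abs(torch_output - onnx_output).max())
    print(f"Max absolute PyTorch/ONNX output difference: {max_diff:.3e}")
    return max_diff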


def verify_onnx_model(onnx_model_path, input_size=256):
    """
    Verify the ONNX model to ensure it was exported correctly.

    Args:
        onnx_model_path: Path to the ONNX model
        input_size: Input size used during export
    """
    # Structural check: the graph must satisfy the ONNX specification
    try:
        onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(onnx_model)
        print("ONNX model is valid")
    except Exception as e:
        print(f"ONNX model validation failed: {e}")
        return False

    # Runtime check: the model must actually run on an input of the exported shape
    try:
        session = ort.InferenceSession(
            onnx_model_path, providers=["CPUExecutionProvider"]
        )
        input_data = np.random.rand(1, 3, input_size, input_size).astype(np.float32)

        # Get input and output names
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        # Run inference
        outputs = session.run([output_name], {input_name: input_data})
        print(f"ONNX model inference test passed. Output shape: {outputs[0].shape}")
        return True
    except Exception as e:
        print(f"ONNX model inference test failed: {e}")
        return False


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print(
            "Usage: python -m utils.onnx_converter <pytorch_model_path> <onnx_output_path> [input_size]"
        )
        sys.exit(1)

    pytorch_model_path = sys.argv[1]
    onnx_output_path = sys.argv[2]
    input_size = int(sys.argv[3]) if len(sys.argv) > 3 else 256
    convert_to_onnx(pytorch_model_path, onnx_output_path, input_size)
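
# Example invocation (paths are illustrative, not part of the original script):
#   python -m utils.onnx_converter checkpoints/best_model.pth models/attention_unet.onnx 256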