import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image
from torchvision.models import ResNet50_Weights

# Reuse the exact transforms ResNet-50 was trained with (resize, crop, normalize)
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Load the exported ONNX model; the CPU provider is enough for a demo Space
ort_session = ort.InferenceSession("resnet50.onnx", providers=["CPUExecutionProvider"])


def preprocess_inputs(img: Image.Image) -> np.ndarray:
    # Apply the ResNet-50 transforms, then convert the resulting tensor
    # to a batched float32 NumPy array of shape (1, 3, 224, 224)
    img = preprocess(img)
    img_array = np.array(img).astype(np.float32)
    img_array = np.expand_dims(img_array, axis=0)
    return img_array


def predict(img: Image.Image) -> str:
    img = preprocess_inputs(img)
    # ONNX Runtime expects a dict mapping input names to arrays
    ort_inputs = {ort_session.get_inputs()[0].name: img}
    ort_outputs = ort_session.run(None, ort_inputs)
    # Take the class with the highest score and look up its human-readable label
    label_index = np.argmax(ort_outputs[0], axis=1).item()
    predicted_label = weights.meta["categories"][label_index]
    return predicted_label


with gr.Blocks() as demo:
    gr.Markdown("# ResNet-50 Using ONNX Runtime")
    gr.Markdown("Upload any image and see if ResNet-50 can classify it! (1000 possible image classes)")
    with gr.Row():
        image_input = gr.Image(type="pil", image_mode="RGB", label="Input Image")
        label_output = gr.Label(label="Predicted Label")
    # Run the classifier whenever a new image is uploaded
    image_input.upload(fn=predict, inputs=image_input, outputs=label_output)
    gr.Markdown("Part of a tutorial on [how to deploy an ONNX model to Hugging Face](https://liamgroen.nl/posts/day-6-deploying-model-to-huggingface-spaces-through-onnx/index.html)")

demo.launch()
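
# The app above assumes a "resnet50.onnx" file sits next to it. If you still need to
# create that file, a minimal one-off export sketch (not part of the app, and only an
# illustration using torchvision's pretrained ResNet-50) could look like this:
#
#     import torch
#     from torchvision.models import resnet50, ResNet50_Weights
#
#     model = resnet50(weights=ResNet50_Weights.DEFAULT).eval()
#     dummy_input = torch.randn(1, 3, 224, 224)  # same (1, 3, 224, 224) shape the app feeds in
#     torch.onnx.export(
#         model,
#         dummy_input,
#         "resnet50.onnx",
#         input_names=["input"],
#         output_names=["output"],
#     )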