|
import gradio as gr |
|
import torch |
|
from torchvision import transforms |
|
import numpy as np |
|
|
|
|
|
# Load a ResNet-101 from the torchvision hub entry point and switch it to
# inference mode (disables dropout / batch-norm statistic updates).
#
# `pretrained=True` is all that is needed to get the ImageNet weights from the
# v0.9.0 hub entry point. The previous extra `weights='imagenet'` keyword was
# forwarded to the ResNet constructor, which has no such parameter, so the
# script failed with a TypeError before the UI could start.
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet101', pretrained=True)
model.eval()

# Labels shown in the UI, in output-index order.
# NOTE(review): the network loaded above is the stock pretrained ResNet-101,
# whose classification head presumably still has the 1000 ImageNet outputs.
# Unless fine-tuned two-class weights are loaded somewhere, indices 0 and 1 do
# not correspond to "normal" / "pneumonia" -- confirm and load the fine-tuned
# checkpoint before relying on these predictions.
class_names = ["normal", "pneumonia"]
|
|
|
|
|
# Standard ImageNet preprocessing: resize, centre-crop to 224x224, convert to a
# tensor and normalise with the ImageNet channel statistics. The pipeline is
# stateless, so it is built once here instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(img):
    """Classify one image and return a label -> probability mapping.

    Args:
        img: The input image, either a PIL image or a numpy array (the form
            Gradio passes when the image input is configured with
            ``type="numpy"``).

    Returns:
        A dict mapping each name in ``class_names`` to the softmax probability
        the model assigns to the output at the same index.

    Raises:
        ValueError: If ``img`` is ``None`` (e.g. the form was submitted
            without an image).
    """
    if img is None:
        raise ValueError("No image provided")

    # The original code called `.convert` directly, which only exists on PIL
    # images; numpy input therefore crashed with AttributeError. Convert
    # arrays to PIL first so both input kinds are supported.
    if isinstance(img, np.ndarray):
        img = transforms.ToPILImage()(img)
    img = img.convert("RGB")

    # Add the leading batch dimension the model expects.
    batch = _PREPROCESS(img).unsqueeze(0)

    with torch.no_grad():
        outputs = model(batch)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # zip() stops at the shorter sequence, so only the first
    # len(class_names) probabilities are reported.
    return {name: float(prob) for name, prob in zip(class_names, probabilities)}
|
|
|
|
|
# Web UI: one image input, one label output showing the class confidences.
# The image is requested as a PIL image because `predict` calls the PIL
# method `.convert("RGB")` on it; with type="numpy" Gradio would hand over an
# ndarray, which has no such method.
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil", label="Input Image"),
    outputs=gr.outputs.Label(num_top_classes=2, label="Predicted Class"),
    title="PneumoniaDetector π",
    description="A ResNet101 computer vision model to detect pneumonia",
    article="Please add a chest X-Ray image",
)

# Start the Gradio server; debug=True keeps the process attached and prints
# errors raised inside `predict`.
iface.launch(debug=True)
|
|