# Pneumonia / app.py — Hugging Face Space by DHEIVER
# Commit 5ed1788 ("Update app.py"), 1.44 kB
import gradio as gr
import torch
from torchvision import transforms
import numpy as np
# Load a ResNet-101 with pretrained ImageNet weights from the torchvision v0.9.0
# hub entry point. In that release `pretrained=True` is how weights are requested;
# the newer `weights=` keyword does not exist there and would be passed through to
# the model constructor, which rejects it with a TypeError — so it is not passed.
# NOTE(review): these are generic ImageNet classification weights. Nothing in this
# file fine-tunes the network on chest X-rays, so its outputs are not pneumonia
# predictions; load a checkpoint trained for this task before trusting results.
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet101', pretrained=True)
# Inference only: put dropout / batch-norm layers into evaluation mode.
model.eval()
# Labels shown in the UI; predict() pairs them, in order, with output columns 0, 1.
class_names = ["normal", "pneumonia"]
# Define predict function
# Preprocessing pipeline, built once at import time instead of on every request:
# resize/crop to 224x224 and normalize with the ImageNet channel statistics.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Converts numpy arrays (e.g. from a Gradio image input with type="numpy") to PIL.
_TO_PIL = transforms.ToPILImage()


def predict(img):
    """Run the model on one image and return a score per entry of ``class_names``.

    Args:
        img: A PIL image, or a numpy array as delivered by a Gradio image input
            (presumably HxW or HxWxC uint8 — confirm against the Gradio version).
            ``None`` (nothing submitted) yields an empty result.

    Returns:
        dict mapping each name in ``class_names`` to a float score taken from the
        softmax over the model's outputs, suitable for a ``gr.Label`` output.
    """
    if img is None:
        return {}
    # ndarray has no .convert(); normalize numpy input to a PIL image first.
    if isinstance(img, np.ndarray):
        img = _TO_PIL(img)
    # Force 3 channels: X-rays are often grayscale, and uploads may carry alpha.
    img = img.convert("RGB")
    batch = _PREPROCESS(img).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # NOTE(review): the softmax runs over every model output and only the first
    # len(class_names) entries are reported. With the ImageNet-pretrained model
    # loaded in this file those columns are ImageNet classes, not normal /
    # pneumonia, so the scores are not meaningful until a 2-class model is used.
    return {name: float(p) for name, p in zip(class_names, probabilities)}
# Create the Gradio interface.
# predict() calls PIL's .convert(), so request a PIL image; type="numpy" handed it
# an ndarray, which has no such method. The gr.inputs / gr.outputs namespaces are
# deprecated (removed in Gradio 4); the top-level gr.Image / gr.Label components
# accept the same arguments.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Label(num_top_classes=2, label="Predicted Class"),
    # NOTE(review): the emoji in this title appears mis-encoded in this copy of
    # the file; confirm the intended character before changing it.
    title="PneumoniaDetector πŸ‘",
    description="A ResNet101 computer vision model to detect pneumonia",
    article="Please add a chest X-Ray image",
)
# Launch the interface; see Gradio's launch() docs for what debug=True does.
iface.launch(debug=True)