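# Hugging Face file-view header captured with this source (not code):
# author: devadethanr · commit d8df7f2 "Update app.py" (verified) · 1.82 kB
# page links: raw | history | blame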
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import numpy as np
from PIL import Image
# Hub repository holding both the fine-tuned classifier weights and the
# image processor configuration used at training time.
model_name = "devadethanr/alz_model"

# The processor and the model are independent downloads from the same repo.
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Class-index -> human-readable class name, as stored in the model config
# (transformers normalizes these keys to ints when loading the config).
labels = model.config.id2label
# Prediction function (with preprocessing) used by the Gradio interface.
def _to_rgb_array(image):
    """Return *image* as an (H, W, 3) NumPy array suitable for the processor.

    Accepts a PIL image of any mode, or an array-like that is 2-D grayscale,
    3-D single-channel, 3-D RGB, or 3-D RGBA. Uploads may already be RGB
    (Gradio's ``gr.Image`` defaults to ``image_mode="RGB"``), so blindly adding
    a channel axis would yield an invalid (H, W, 3, 3) array.
    """
    if isinstance(image, Image.Image):
        # PIL performs the mode conversion (L, LA, P, RGBA, ...) in one call.
        return np.array(image.convert("RGB"))
    array = np.asarray(image)
    if array.ndim == 2:
        # Grayscale: replicate the single channel three times.
        return np.repeat(array[:, :, np.newaxis], 3, axis=2)
    if array.ndim == 3 and array.shape[2] == 1:
        return np.repeat(array, 3, axis=2)
    if array.ndim == 3 and array.shape[2] == 4:
        # Drop the alpha channel.
        return array[:, :, :3]
    return array


def predict_image(image):
    """
    Predicts the Alzheimer's disease stage from an uploaded MRI image.

    Args:
        image: The uploaded MRI image. Normally a PIL Image (the Gradio input
            uses ``type="pil"``); grayscale, RGB and RGBA images, as well as
            equivalent NumPy arrays, are accepted.

    Returns:
        A ``(predicted_label, confidences)`` tuple: the name of the most
        likely class, and a dict mapping each class name to its softmax
        probability.

    Raises:
        gr.Error: If no image was provided (e.g. Submit pressed with no upload).
    """
    if image is None:
        raise gr.Error("Please upload an MRI image.")

    # Preprocessing: the image processor expects a three-channel image.
    rgb_image = _to_rgb_array(image)

    # Model inference.
    inputs = processor(images=rgb_image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label_id = logits.argmax(-1).item()

    # transformers stores config.id2label with int keys, so index by int;
    # a str(...) lookup raises KeyError.
    predicted_label = labels[predicted_label_id]

    # Softmax over the class dimension of the single image in the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    confidences = {labels[i]: float(p) for i, p in enumerate(probabilities)}
    return predicted_label, confidences
# Gradio UI: one PIL image in; the top predicted label and the full
# per-class confidence dictionary out (in the order predict_image returns them).
_mri_input = gr.Image(type="pil", label="Upload MRI Image")
_result_outputs = [
    gr.Label(label="Prediction"),
    gr.JSON(label="Confidence Scores"),
]

iface = gr.Interface(
    fn=predict_image,
    inputs=_mri_input,
    outputs=_result_outputs,
    title="Alzheimer's Disease MRI Image Classifier",
    description="Upload an MRI image to predict the stage of Alzheimer's disease.",
)

# Start the web server as a module-level side effect (the script entry point).
iface.launch()