File size: 1,652 Bytes
9865f77 18841bc 9865f77 18841bc 9865f77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# import gradio as gr
# gr.load("models/devadethanr/alz_model").launch()
import gradio as gr
import numpy as np
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Load the classifier and its matching image processor from the Hub once, at
# import time, so every request reuses them.
model_name = "devadethanr/alz_model"
model = AutoModelForImageClassification.from_pretrained(model_name)
image_processor = AutoImageProcessor.from_pretrained(model_name)
# Mapping from integer class id to human-readable label name, taken from the
# model's configuration.
labels = model.config.id2label


def predict_image(image):
    """Predict the Alzheimer's disease stage shown in an uploaded MRI image.

    The image is converted to RGB, preprocessed with the model's own image
    processor and classified. A softmax over the logits gives per-class
    probabilities.

    Args:
        image: The uploaded MRI image as a PIL Image. Gradio passes ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(predicted_label, confidences)`` tuple. ``predicted_label`` is the
        name of the most likely class. ``confidences`` is a dict mapping every
        class name to its probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image.")

    # Image processors expect 3-channel input. Grayscale ("L") or RGBA uploads
    # would otherwise fail or be mis-normalized.
    # NOTE(review): assumes the checkpoint was trained on RGB inputs — confirm.
    inputs = image_processor(images=image.convert("RGB"), return_tensors="pt")
    inputs = inputs.to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_label_id = logits.argmax(-1).item()
    predicted_label = labels[predicted_label_id]

    # Probabilities for the single image in the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Key the scores by label *name*. Iterating the id2label dict directly
    # would yield its integer ids instead.
    confidences = {
        label: float(probabilities[int(class_id)])
        for class_id, label in labels.items()
    }
    return predicted_label, confidences
# Build and launch the web UI. One uploaded image goes in; the predicted label
# and the per-class confidence scores returned by predict_image come out.
iface = gr.Interface(
    fn=predict_image,
    # The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
    # removed in Gradio 4. Components now live at the top level of `gr`.
    inputs=gr.Image(type="pil", label="Upload MRI Image"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.JSON(label="Confidence Scores"),
    ],
    title="Alzheimer's Disease MRI Image Classifier",
    description="Upload an MRI image to predict the stage of Alzheimer's disease.",
)
iface.launch()
|