devadethanr's picture
Update app.py
9865f77 verified
raw
history blame
1.65 kB
# import gradio as gr
# gr.load("models/devadethanr/alz_model").launch()
import gradio as gr
import numpy as np
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Load the fine-tuned classifier and its matching image processor from the Hub.
# The processor must come from the same repo so resizing/normalisation match
# what the model was trained with.
model_name = "devadethanr/alz_model"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# Mapping from class index (int) to human-readable label name.
labels = model.config.id2label


def predict_image(image):
    """Predict the Alzheimer's disease stage from an uploaded MRI image.

    Args:
        image: The uploaded MRI image as a PIL Image.

    Returns:
        A ``(predicted_label, confidences)`` tuple: the name of the
        highest-scoring class, and a dict mapping every class name to its
        softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading a file.
        raise gr.Error("Please upload an MRI image.")
    # Force 3 channels: scans may arrive grayscale/RGBA, and the processor's
    # normalisation presumably uses per-RGB-channel mean/std (confirm against
    # the repo's preprocessor_config.json).
    inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single image in the batch, so take row 0 of the (batch, num_labels) output.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_label_id = int(probabilities.argmax().item())
    predicted_label = labels[predicted_label_id]
    # id2label maps index -> name; iterate its items, since iterating the dict
    # itself would yield only the integer indices.
    confidences = {label: float(probabilities[i]) for i, label in labels.items()}
    return predicted_label, confidences
# Build the Gradio UI: one PIL image input; outputs are the predicted label
# and the per-class confidence scores returned by predict_image.
# Uses the top-level component classes: the legacy gr.inputs / gr.outputs
# namespaces were deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload MRI Image"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.JSON(label="Confidence Scores"),
    ],
    title="Alzheimer's Disease MRI Image Classifier",
    description="Upload an MRI image to predict the stage of Alzheimer's disease.",
)
iface.launch()