File size: 2,205 Bytes
18841bc d8df7f2 9865f77 d8df7f2 9865f77 d8df7f2 767a524 9865f77 d8df7f2 9865f77 4ecb9f3 9865f77 d8df7f2 4ecb9f3 9865f77 d8df7f2 9865f77 d8df7f2 9865f77 4ecb9f3 9865f77 4ecb9f3 9865f77 4ecb9f3 9865f77 4ecb9f3 9865f77 4ecb9f3 9865f77 d8df7f2 9865f77 4ecb9f3 9865f77 18841bc 9865f77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import numpy as np
from PIL import Image
# Fetch the pretrained classifier and its matching image preprocessor from
# the Hugging Face Hub; both are resolved from the same repository id.
model_name = "devadethanr/alz_model"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Class-index -> class-name mapping stored in the model's config; used to
# turn logits positions into human-readable predictions.
labels = model.config.id2label
print(f"Labels: {labels}")  # Startup sanity check: show the classes the model can predict.
# Define the prediction function (with preprocessing)
def predict_image(image):
    """Predict the Alzheimer's disease stage from an uploaded MRI image.

    Args:
        image: The uploaded MRI image, as a PIL Image (what the
            ``gr.Image(type="pil")`` input supplies) or a NumPy array.

    Returns:
        A ``(predicted_label, confidences)`` tuple: the name of the
        highest-scoring class, and a dict mapping every class name in
        ``model.config.id2label`` to its softmax probability.

    Raises:
        gr.Error: If no image was provided (Gradio passes ``None`` when the
            input component is empty).
    """
    if image is None:
        # Surface a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an MRI image.")

    # Normalize to a 3-channel RGB PIL image. convert("RGB") replicates a
    # grayscale ("L") image into three channels (what the old np.repeat did)
    # and also drops the alpha channel of RGBA / palette images, which would
    # otherwise reach the processor as a 4-channel array.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Resize to 224x224 when needed. NOTE(review): 224 is presumably the
    # model's input size; the processor may also resize per its own config.
    if image.size != (224, 224):
        image = image.resize((224, 224))

    # Model inference (no gradients needed at serving time).
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was passed, so take the first (only) row of the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_label = labels[int(probabilities.argmax())]
    confidences = {
        label: float(probabilities[idx]) for idx, label in labels.items()
    }
    return predicted_label, confidences
# Build the Gradio UI: one image upload in; two outputs matching
# predict_image's return tuple (the top label, then the per-class
# probability dict).
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload MRI Image"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.JSON(label="Confidence Scores"),
    ],
    title="Alzheimer's Disease MRI Image Classifier",
    description="Upload an MRI image to predict the stage of Alzheimer's disease.",
)

# Only start the app when this file is run as a script, so importing the
# module does not call launch() as a side effect.
if __name__ == "__main__":
    iface.launch()
|