import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import numpy as np
from PIL import Image

# Load the model and its matching image processor from the Hub
model_name = "devadethanr/alz_model"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Get the label names from the model's configuration
labels = model.config.id2label

# Class-index -> label-name table keyed by int. transformers normalises
# id2label keys to int when it loads a config, so looking labels up by
# str(index) fails; the int() cast here also tolerates string keys.
_LABELS_BY_ID = {int(class_id): name for class_id, name in labels.items()}


def _to_rgb_array(image):
    """Return *image* as an (H, W, 3) array, whatever its channel layout."""
    if isinstance(image, Image.Image):
        # Let PIL resolve palette / alpha / grayscale modes correctly.
        image = image.convert("RGB")

    pixels = np.asarray(image)
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    if pixels.ndim != 3 or pixels.shape[2] not in (1, 3, 4):
        raise ValueError(
            f"Expected a grayscale, RGB or RGBA image, got shape {pixels.shape}"
        )
    if pixels.shape[2] == 1:
        # Only single-channel data needs replicating; colour input is
        # passed through untouched.
        pixels = np.repeat(pixels, 3, axis=2)
    # Drop the alpha channel if one is present.
    return pixels[:, :, :3]


# Define the prediction function (with preprocessing)
def predict_image(image):
    """
    Predict the Alzheimer's disease stage from an uploaded MRI image.

    Args:
        image: The uploaded MRI image. A PIL Image in any mode, or an
            array-like of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).

    Returns:
        A tuple ``(predicted_label, confidences)`` where ``predicted_label``
        is the name of the highest-scoring class and ``confidences`` maps
        every class name to its softmax probability.

    Raises:
        gr.Error: If no image was supplied.
        ValueError: If the image has an unsupported channel layout.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading.
        raise gr.Error("Please upload an MRI image before submitting.")

    # Model inference
    inputs = processor(
        images=_to_rgb_array(image), return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_label = _LABELS_BY_ID[int(logits.argmax(-1).item())]

    # Calculate probabilities using softmax (batch of one -> first row)
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    confidences = {
        _LABELS_BY_ID[class_id]: probability
        for class_id, probability in enumerate(probabilities)
    }

    return predicted_label, confidences


# Create the Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload MRI Image"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.JSON(label="Confidence Scores"),
    ],
    title="Alzheimer's Disease MRI Image Classifier",
    description="Upload an MRI image to predict the stage of Alzheimer's disease.",
)

iface.launch()