File size: 2,205 Bytes
18841bc
d8df7f2
9865f77
 
d8df7f2
9865f77
d8df7f2
767a524
9865f77
d8df7f2
9865f77
 
 
 
4ecb9f3
 
9865f77
 
 
 
 
 
 
 
 
 
 
d8df7f2
 
4ecb9f3
 
 
 
 
 
9865f77
d8df7f2
 
9865f77
d8df7f2
9865f77
4ecb9f3
 
 
9865f77
4ecb9f3
9865f77
 
 
4ecb9f3
9865f77
 
 
4ecb9f3
9865f77
 
4ecb9f3
9865f77
d8df7f2
 
9865f77
 
4ecb9f3
9865f77
18841bc
9865f77
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import numpy as np
from PIL import Image

# Hub checkpoint that provides both the fine-tuned classifier weights and the
# image-processor configuration saved alongside them.
model_name = "devadethanr/alz_model"

# Fetch (or reuse from the local cache) the classifier and its preprocessor.
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Mapping from class index to human-readable class name, as stored in the
# checkpoint's config.
labels = model.config.id2label

# Debugging statement to check the labels
print(f"Labels: {labels}")

# Define the prediction function (with preprocessing)
def predict_image(image):
    """
    Predicts the Alzheimer's disease stage from an uploaded MRI image.

    Args:
        image: The uploaded MRI image (PIL Image). A NumPy array is also
            accepted. ``None`` (Submit pressed without an upload) is handled
            instead of crashing.

    Returns:
        A ``(predicted_label, confidences)`` tuple: the name of the
        highest-scoring class and a dict mapping every class name to its
        softmax probability. Returns ``(None, None)`` when no image is given.
    """
    if image is None:
        # np.array(None) is a 0-d array, which used to crash on
        # image.shape[0] with an IndexError; clear both outputs instead.
        return None, None

    # Preprocessing steps:
    # Normalize to 3-channel RGB through PIL. Besides plain grayscale, this
    # also covers RGBA / palette / 16-bit inputs that a manual channel repeat
    # would pass through with the wrong channel count or dtype.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    image = image.convert("RGB")

    # Resize to 224x224 (Image.resize default resampling), as before. The
    # processor below presumably applies its own checkpoint-configured
    # resize/normalization as well -- confirm against the preprocessor config.
    if image.size != (224, 224):
        image = image.resize((224, 224))
    image = np.array(image)

    # Model inference:
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    print(f"logits shape: {logits.shape}")  # Debugging statement to check shape
    print(f"logits: {logits}")  # Debugging statement to check content

    # Batch of one image: take the arg-max class over the last (class) axis.
    predicted_label = labels[logits.argmax(-1).item()]

    # Calculate probabilities using softmax; row 0 is the single input image.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    confidences = {labels[i]: probabilities[i] for i in range(len(labels))}

    return predicted_label, confidences

# Assemble the web UI: one PIL image goes in; the top predicted label and a
# JSON view of every class probability come out.
_mri_upload = gr.Image(type="pil", label="Upload MRI Image")
_prediction_outputs = [
    gr.Label(label="Prediction"),
    gr.JSON(label="Confidence Scores"),
]

iface = gr.Interface(
    fn=predict_image,
    inputs=_mri_upload,
    outputs=_prediction_outputs,
    title="Alzheimer's Disease MRI Image Classifier",
    description="Upload an MRI image to predict the stage of Alzheimer's disease.",
)

# Start the Gradio server.
iface.launch()