File size: 4,428 Bytes
d8a6e20
 
 
 
 
 
 
 
 
50a18a1
 
 
 
d8a6e20
27b2c9b
e873a09
 
27b2c9b
 
 
e873a09
27b2c9b
d8a6e20
 
50a18a1
9131c16
 
 
 
50a18a1
 
 
 
 
27b2c9b
50a18a1
 
 
27b2c9b
 
 
 
 
 
 
 
 
 
 
 
50a18a1
9131c16
50a18a1
9131c16
 
 
27b2c9b
50a18a1
d8a6e20
50a18a1
9131c16
 
 
 
50a18a1
 
 
 
 
 
 
 
 
9131c16
 
 
 
 
 
 
 
 
50a18a1
27b2c9b
9131c16
 
 
 
 
 
 
27b2c9b
9131c16
 
 
 
27b2c9b
50a18a1
 
27b2c9b
9131c16
50a18a1
 
9131c16
 
 
 
 
 
 
27b2c9b
4ef367f
50a18a1
4ef367f
 
9131c16
7cfbbcd
 
4ef367f
 
9131c16
d8a6e20
 
 
27b2c9b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# Load the pretrained Vision Transformer model and image processor.
# NOTE(review): this checkpoint is an ImageNet-1k classifier (1000 classes),
# not a chest X-ray model — confirm this is the intended backbone.
model_name = "google/vit-base-patch16-224"
try:
    model = ViTForImageClassification.from_pretrained(model_name)
except Exception as e:
    print(f"Error loading model: {e}")
    # Re-raise: swallowing the error here would leave `model` undefined and
    # only defer the failure to a confusing NameError at first prediction.
    raise
image_processor = ViTImageProcessor.from_pretrained(model_name)

# NIH Chest X-ray predefined conditions (14 classes).
# NOTE(review): the loaded ViT head predicts ImageNet-1k indices (0-999),
# so most predicted class ids fall outside this 14-entry list — verify the
# model/label pairing before indexing with a raw prediction.
labels = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
    "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema", 
    "Fibrosis", "Pleural Thickening", "Hernia"
]

# Function to apply Grad-CAM visualization
def generate_grad_cam(image, target_layer):
    try:
        # Convert image to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Preprocess the image
        inputs = image_processor(images=image, return_tensors="pt")
        
        # Forward pass to get logits
        input_tensor = inputs["pixel_values"]
        input_tensor.requires_grad = True  # Enable gradient tracking
        outputs = model(input_tensor)
        logits = outputs.logits
        
        # Get the predicted class and calculate gradients
        predicted_class = logits.argmax(-1)
        class_score = logits[0, predicted_class]
        class_score.backward()
        
        # Get gradients and weights from the target layer
        gradients = model.get_input_embeddings().weight.grad
        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
        
        # Apply Grad-CAM calculation (modify this part as per the model architecture)
        cam = torch.mean(pooled_gradients * inputs["pixel_values"], dim=1).squeeze()
        cam = torch.clamp(cam, min=0).numpy()  # Ensure non-negative values
        
        return cam, predicted_class.item(), None  # No error
    except Exception as e:
        error_message = f"Error generating Grad-CAM: {e}"
        print(error_message)
        return None, None, error_message

# Function to predict the class and render its Grad-CAM heat-map
def predict_and_explain(image):
    """Classify *image* and produce a Grad-CAM visualization.

    Args:
        image: PIL image from the Gradio input component.

    Returns:
        A 3-tuple (predicted label, Grad-CAM PIL image or None, error-log
        string) matching the three output components declared on the
        gr.Interface. (The previous string-keyed dict return did not match
        Gradio's list-of-components output contract.)
    """
    try:
        # Convert image to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Single forward/backward pass: generate_grad_cam already returns
        # the predicted class, so no separate prediction pass is needed.
        cam_map, predicted_class, grad_cam_error = generate_grad_cam(image, "pooler_output")

        # Check for Grad-CAM errors
        if grad_cam_error is not None:
            return "Error during Grad-CAM generation", None, grad_cam_error

        # The backbone predicts ImageNet-1k indices while only 14 NIH
        # labels are defined — guard against an IndexError.
        if 0 <= predicted_class < len(labels):
            label = labels[predicted_class]
        else:
            label = f"class {predicted_class} (outside NIH label set)"

        # Convert cam_map to a visualizable format (heatmap)
        if cam_map is not None:
            plt.imshow(cam_map, cmap='jet', alpha=0.5)
            plt.axis('off')
            plt.title(f"Grad-CAM for {label}")
            plt.colorbar()
            plt.savefig("grad_cam_output.png")
            plt.close()

            # Load the saved image to return it
            grad_cam_image = Image.open("grad_cam_output.png")
        else:
            grad_cam_image = None

        return label, grad_cam_image, "No errors"
    except Exception as e:
        error_message = f"Error predicting and explaining: {e}"
        print(error_message)
        return "Error during prediction", None, error_message

# Gradio front-end: one PIL image in, three components out
# (predicted label, Grad-CAM heat-map, error log).
_output_components = [
    gr.Textbox(label="Predicted Class"),
    gr.Image(label="Grad-CAM Map"),
    gr.Textbox(label="Error Log"),
]

iface = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Image(type="pil"),
    outputs=_output_components,
    title="Chest X-ray Classification with Debugging Logs",
)

# Launch the app only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()