Spaces:
Sleeping
Sleeping
File size: 4,428 Bytes
d8a6e20 50a18a1 d8a6e20 27b2c9b e873a09 27b2c9b e873a09 27b2c9b d8a6e20 50a18a1 9131c16 50a18a1 27b2c9b 50a18a1 27b2c9b 50a18a1 9131c16 50a18a1 9131c16 27b2c9b 50a18a1 d8a6e20 50a18a1 9131c16 50a18a1 9131c16 50a18a1 27b2c9b 9131c16 27b2c9b 9131c16 27b2c9b 50a18a1 27b2c9b 9131c16 50a18a1 9131c16 27b2c9b 4ef367f 50a18a1 4ef367f 9131c16 7cfbbcd 4ef367f 9131c16 d8a6e20 27b2c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Load the pretrained Vision Transformer model and image processor.
model_name = "google/vit-base-patch16-224"
try:
    model = ViTForImageClassification.from_pretrained(model_name)
    image_processor = ViTImageProcessor.from_pretrained(model_name)
except Exception as e:
    # Nothing in this app can work without these objects. Report the cause and
    # stop here rather than failing later with a confusing NameError on `model`.
    print(f"Error loading model: {e}")
    raise
# Inference mode (dropout off). from_pretrained normally does this already;
# being explicit protects against later code toggling train mode.
model.eval()

# NIH Chest X-ray predefined conditions.
# NOTE(review): these names only describe the model's outputs if the checkpoint's
# classifier head has exactly len(labels) outputs. The generic
# "google/vit-base-patch16-224" checkpoint presumably was not fine-tuned on NIH
# Chest X-ray -- confirm against model.config.id2label / num_labels.
labels = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
    "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema",
    "Fibrosis", "Pleural Thickening", "Hernia",
]
# Grad-CAM for the ViT classifier: weight one layer's token activations by the
# gradient of the predicted class score and keep only the positive evidence.
def _resolve_grad_cam_layer(target_layer):
    """Return the submodule named ``target_layer``, or the default ViT layer.

    ``target_layer`` is looked up in ``model.named_modules()``. A name that is not
    a submodule (e.g. ``"pooler_output"``, which is a model output, not a layer)
    falls back to ``layernorm_before`` of the last encoder block. Its patch-token
    outputs still feed that block's self-attention, and so the [CLS] token the
    ViT classification head reads. Their gradients are therefore non-zero, unlike
    the patch tokens of the final layer output.
    """
    modules = dict(model.named_modules())
    # "" names the root model itself, whose output is not a token tensor.
    if target_layer and target_layer in modules:
        return modules[target_layer]
    return model.vit.encoder.layer[-1].layernorm_before


def generate_grad_cam(image, target_layer=None):
    """Compute a Grad-CAM heatmap for the model's top prediction on ``image``.

    Args:
        image: PIL image; converted to RGB if it is not already.
        target_layer: Name of a submodule (as listed by ``model.named_modules()``)
            whose output tokens are explained. Unknown names or None fall back
            to the last encoder block's ``layernorm_before``.

    Returns:
        ``(cam, predicted_class, error)``.

        On success:
            - ``cam`` is a 2-D float numpy array on the square patch grid,
              scaled to [0, 1].
            - ``predicted_class`` is the argmax logit index (int).
            - ``error`` is None.

        On failure the first two items are None and ``error`` is a message
        string, which is also printed.
    """
    try:
        # Convert image to RGB if necessary.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = image_processor(images=image, return_tensors="pt")
        layer = _resolve_grad_cam_layer(target_layer)

        captured = {}

        def _capture(module, args, output):
            # Some layers return a tuple whose first item is the hidden states.
            captured["tokens"] = output[0] if isinstance(output, tuple) else output

        handle = layer.register_forward_hook(_capture)
        try:
            # Gradients are needed here even if a caller disabled them globally.
            with torch.enable_grad():
                logits = model(inputs["pixel_values"]).logits
                predicted_class = int(logits[0].argmax())
                tokens = captured["tokens"]
                # autograd.grad yields d(score)/d(tokens) directly, without
                # accumulating .grad on every model parameter as .backward() does.
                (token_grads,) = torch.autograd.grad(logits[0, predicted_class], tokens)
        finally:
            handle.remove()

        # Drop the leading [CLS] token; the remaining tokens are the image patches.
        activations = tokens[0, 1:].detach()
        token_grads = token_grads[0, 1:]
        num_patches = activations.shape[0]
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(f"cannot arrange {num_patches} tokens into a square grid")

        # Grad-CAM channel weights: gradients averaged over all patch positions.
        weights = token_grads.mean(dim=0)
        cam = torch.relu((activations * weights).sum(dim=-1)).reshape(side, side)
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam.cpu().numpy(), predicted_class, None  # No error
    except Exception as e:
        error_message = f"Error generating Grad-CAM: {e}"
        print(error_message)
        return None, None, error_message
def _label_for(class_index):
    """Return a human-readable name for a model output index.

    ``labels`` is used only when the classifier head has exactly ``len(labels)``
    outputs. Otherwise the checkpoint's own ``model.config.id2label`` is used, so
    an index beyond the 14 NIH labels cannot raise IndexError.
    """
    if model.config.num_labels == len(labels):
        return labels[class_index]
    return model.config.id2label.get(class_index, f"class {class_index}")


def _render_grad_cam(image, cam_map, title_label):
    """Draw ``cam_map`` as a semi-transparent heatmap over ``image``; return a PIL image.

    The heatmap is drawn on a standalone matplotlib ``Figure`` and saved to an
    in-memory PNG buffer. This avoids pyplot's global current figure and a shared
    file on disk, so concurrent requests cannot draw into or overwrite each
    other's output.
    """
    import io
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    width, height = image.size
    ax.imshow(image)
    # Stretch the coarse CAM grid over the full image so the two line up.
    heat = ax.imshow(cam_map, cmap='jet', alpha=0.5, extent=(0, width, height, 0))
    ax.axis('off')
    ax.set_title(f"Grad-CAM for {title_label}")
    fig.colorbar(heat, ax=ax)

    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()  # decode now instead of lazily on first access
    return rendered


# Function to predict classes and visualize Grad-CAM
def predict_and_explain(image):
    """Classify ``image`` and attach a Grad-CAM visualisation of the prediction.

    Args:
        image: PIL image, or None (e.g. when nothing was uploaded).

    Returns:
        dict with three keys:
            - ``"predicted class"``: the label string, or an error summary.
            - ``"Grad-CAM map"``: a PIL image, or None.
            - ``"error log"``: ``"No errors"`` or the error message.

        Failures are reported in this dict (and printed) rather than raised.
    """
    try:
        if image is None:
            raise ValueError("no image was provided")
        # Convert image to RGB if necessary.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = image_processor(images=image, return_tensors="pt")
        # Plain inference: the prediction itself needs no autograd graph.
        with torch.no_grad():
            logits = model(inputs["pixel_values"]).logits
        predicted_class = logits.argmax(-1).item()
        predicted_label = _label_for(predicted_class)

        cam_map, _, grad_cam_error = generate_grad_cam(image, "pooler_output")
        if grad_cam_error is not None:
            return {
                "predicted class": "Error during Grad-CAM generation",
                "Grad-CAM map": None,
                "error log": grad_cam_error,
            }

        grad_cam_image = None
        if cam_map is not None:
            grad_cam_image = _render_grad_cam(image, cam_map, predicted_label)
        return {
            "predicted class": predicted_label,
            "Grad-CAM map": grad_cam_image,
            "error log": "No errors",
        }
    except Exception as e:
        error_message = f"Error predicting and explaining: {e}"
        print(error_message)
        return {
            "predicted class": "Error during prediction",
            "Grad-CAM map": None,
            "error log": error_message,
        }
# Gradio assigns a function's return values to `outputs` positionally (a tuple for
# several components). It does not interpret a dict keyed by plain strings, so
# predict_and_explain's dict is unpacked here in the same order as `outputs`.
_OUTPUT_KEYS = ("predicted class", "Grad-CAM map", "error log")


def _predict_for_interface(image):
    """Adapt predict_and_explain's dict result to the (text, image, text) tuple."""
    result = predict_and_explain(image)
    return tuple(result[key] for key in _OUTPUT_KEYS)


iface = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Image(type="pil"),  # the handler receives a PIL image
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Image(label="Grad-CAM Map"),
        gr.Textbox(label="Error Log"),
    ],
    title="Chest X-ray Classification with Debugging Logs",
)

if __name__ == "__main__":
    iface.launch()
|