# Hugging Face Space — status: Sleeping
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
# Load the pretrained Vision Transformer model and its matching image processor.
# NOTE: google/vit-base-patch16-224 is an ImageNet-1k classifier; it has not been
# fine-tuned on chest X-rays, so its predictions are ImageNet classes.
model_name = "google/vit-base-patch16-224"
try:
    model = ViTForImageClassification.from_pretrained(model_name)
    image_processor = ViTImageProcessor.from_pretrained(model_name)
except Exception as e:
    # Fail fast: if we continued, `model` / `image_processor` would be undefined
    # and every request would fail later with an unrelated-looking NameError.
    raise RuntimeError(f"Error loading model {model_name!r}: {e}") from e
# Finding names of the NIH ChestX-ray14 dataset, in dataset order.
# NOTE(review): the checkpoint loaded above (google/vit-base-patch16-224) is an
# ImageNet-1k classifier, so its output indices do not correspond to these
# names; they only apply to a model fine-tuned on this dataset in this order.
labels = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural Thickening",
    "Hernia",
]
# Function to compute a Grad-CAM heat-map for the model's top prediction
def generate_grad_cam(image, target_layer):
    """Compute a Grad-CAM heat-map for the model's highest-scoring class.

    ViT has no convolutional feature map, so the activations used are the
    patch-token hidden states of one transformer layer, arranged on the patch
    grid.

    Args:
        image: PIL image to explain; converted to RGB if needed.
        target_layer: index into the model's ``hidden_states`` tuple
            (0 = embedding output, -1 = last encoder layer). Any non-int value
            (e.g. a layer name string) selects the last encoder layer.

    Returns:
        ``(cam, predicted_class, None)`` on success. ``cam`` is a 2-D float32
        numpy array over the patch grid (e.g. 14x14 for a 224px image with
        16px patches), scaled so its maximum is 1. ``predicted_class`` is the
        argmax logit index as an int.
        ``(None, None, error_message)`` if anything fails; the error is also
        printed.
    """
    try:
        # The image processor expects 3-channel input.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = image_processor(images=image, return_tensors="pt")
        # Gradients are required here even if the caller disabled them globally.
        with torch.enable_grad():
            outputs = model(inputs["pixel_values"], output_hidden_states=True)
            logits = outputs.logits
            predicted_class = int(logits[0].argmax())
            layer_index = target_layer if isinstance(target_layer, int) else -1
            activations = outputs.hidden_states[layer_index]
            # Gradient of the winning class score w.r.t. the chosen activations
            # only. autograd.grad (unlike backward) leaves parameter .grad alone,
            # so nothing accumulates across requests.
            (gradients,) = torch.autograd.grad(
                logits[0, predicted_class], activations
            )
        # Drop the leading [CLS] token so that only patch tokens remain.
        patch_activations = activations[0, 1:].detach()
        patch_gradients = gradients[0, 1:]
        # Grad-CAM: one weight per hidden channel = gradient averaged over patches.
        channel_weights = patch_gradients.mean(dim=0)
        cam = torch.relu((patch_activations * channel_weights).sum(dim=-1))
        num_patches = cam.numel()
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"cannot arrange {num_patches} patch tokens on a square grid"
            )
        cam = cam.reshape(side, side)
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam.numpy(), predicted_class, None  # No error
    except Exception as e:
        error_message = f"Error generating Grad-CAM: {e}"
        print(error_message)
        return None, None, error_message
def _label_for(class_index):
    """Return a human-readable name for output index `class_index` of `model`."""
    # The NIH names are only meaningful when the classifier head has exactly that
    # many outputs. Otherwise use the checkpoint's own id2label mapping, which
    # covers every index (no IndexError for a 1000-class head).
    if model.config.num_labels == len(labels):
        return labels[class_index]
    return model.config.id2label.get(class_index, f"class {class_index}")


def _render_grad_cam(image, cam_map, title):
    """Overlay `cam_map` on `image` and return the rendered figure as a PIL image."""
    # A standalone Figure, not pyplot, so that no global figure state or shared
    # file on disk is touched; concurrent requests cannot clobber each other.
    fig = Figure(figsize=(6, 6))
    ax = fig.add_subplot()
    width, height = image.size
    ax.imshow(image)
    # Stretch the coarse patch-grid heat-map over the whole image.
    heat = ax.imshow(
        cam_map,
        cmap='jet',
        alpha=0.5,
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
        interpolation='bilinear',
    )
    ax.axis('off')
    ax.set_title(title)
    fig.colorbar(heat, ax=ax)
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()  # decode now rather than lazily from the buffer
    return rendered


# Function to predict the class and visualize Grad-CAM
def predict_and_explain(image):
    """Classify `image` and build a Grad-CAM overlay for the predicted class.

    Args:
        image: PIL image from the UI, or None if no image was supplied.

    Returns:
        A dict with keys "predicted class" (the label, or an error summary),
        "Grad-CAM map" (PIL image or None) and "error log" ("No errors" on
        success, otherwise the error message).
    """
    if image is None:
        return {
            "predicted class": "Error during prediction",
            "Grad-CAM map": None,
            "error log": "No image provided.",
        }
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # generate_grad_cam runs the forward pass and returns the predicted index,
        # so a second forward pass here would be redundant. -1 = last encoder layer.
        cam_map, predicted_class, grad_cam_error = generate_grad_cam(image, -1)
        if grad_cam_error is not None:
            return {
                "predicted class": "Error during Grad-CAM generation",
                "Grad-CAM map": None,
                "error log": grad_cam_error,
            }
        label = _label_for(predicted_class)
        if cam_map is None:
            grad_cam_image = None
        else:
            grad_cam_image = _render_grad_cam(image, cam_map, f"Grad-CAM for {label}")
        return {
            "predicted class": label,
            "Grad-CAM map": grad_cam_image,
            "error log": "No errors",
        }
    except Exception as e:
        error_message = f"Error predicting and explaining: {e}"
        print(error_message)
        return {
            "predicted class": "Error during prediction",
            "Grad-CAM map": None,
            "error log": error_message,
        }
# Keys of predict_and_explain's result, in the order of the output components below.
_RESULT_KEYS = ("predicted class", "Grad-CAM map", "error log")


def _predict_for_interface(image):
    """Adapt predict_and_explain's dict result to the tuple Gradio expects.

    With several output components, Gradio needs one positional value per
    component. A returned dict is only accepted when its keys are the
    component objects themselves, not strings.
    """
    result = predict_and_explain(image)
    return tuple(result.get(key) for key in _RESULT_KEYS)


iface = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Image(type="pil"),  # handler receives a PIL image (or None)
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Image(label="Grad-CAM Map"),
        gr.Textbox(label="Error Log"),
    ],
    title="Chest X-ray Classification with Debugging Logs",
)

if __name__ == "__main__":
    iface.launch()