GreenThumb / app.py
rtik007's picture
Update app.py
4ef367f verified
raw
history blame
4.43 kB
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Load the pretrained Vision Transformer model and image processor.
model_name = "google/vit-base-patch16-224"
try:
    model = ViTForImageClassification.from_pretrained(model_name)
except Exception as e:
    # Fail fast: swallowing this error would leave `model` undefined, so every
    # later request would die with an unrelated NameError instead of the real
    # cause. Log it, then re-raise the original exception unchanged.
    print(f"Error loading model: {e}")
    raise
image_processor = ViTImageProcessor.from_pretrained(model_name)

# NIH Chest X-ray predefined conditions.
# NOTE(review): these 14 names are indexed by the model's predicted class id;
# confirm the loaded checkpoint actually emits exactly these 14 classes in this
# order -- a checkpoint with a different label set would make the mapping
# meaningless (and out of range for ids >= 14).
labels = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
    "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema",
    "Fibrosis", "Pleural Thickening", "Hernia"
]
# Function to apply Grad-CAM visualization
def generate_grad_cam(image, target_layer):
    """Compute a gradient-based relevance map for the model's top prediction.

    The map is built from the gradient of the top class score with respect to
    the preprocessed input pixels: the gradient is averaged per channel, used
    to weight the input channels, averaged across channels and clamped to be
    non-negative.

    Args:
        image: PIL image; converted to RGB when it is in another mode.
        target_layer: Accepted for interface compatibility; it is not used by
            the computation (the map is taken at the input, not at an
            intermediate layer).

    Returns:
        A 3-tuple ``(cam, predicted_class, error)``:
        on success ``cam`` is a 2-D NumPy array, ``predicted_class`` is the
        arg-max class index as an int and ``error`` is ``None``;
        on failure the tuple is ``(None, None, error_message)``.
        The function never raises; failures are printed and reported through
        the third element.
    """
    try:
        # Convert image to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess the image
        inputs = image_processor(images=image, return_tensors="pt")

        # Work on a private leaf copy so gradient tracking never leaks into
        # the processor's output tensor.
        input_tensor = inputs["pixel_values"].detach().clone().requires_grad_(True)

        # Forward pass to get logits
        outputs = model(input_tensor)
        logits = outputs.logits

        # Score of the top class for the (single) image in the batch.
        predicted_class = int(logits.argmax(-1).item())
        class_score = logits[0, predicted_class]

        # Gradient w.r.t. the input only. Unlike ``backward()``, this does not
        # accumulate ``.grad`` on the model parameters across requests.
        (gradients,) = torch.autograd.grad(class_score, input_tensor)

        # One weight per channel: average over batch and spatial dimensions.
        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

        # Reshape to (1, C, 1, 1) so the per-channel weights broadcast over
        # the (1, C, H, W) input, then collapse the channel axis.
        weighted = pooled_gradients.view(1, -1, 1, 1) * input_tensor.detach()
        cam = torch.mean(weighted, dim=1).squeeze()

        # Ensure non-negative values; the tensor is already detached from the
        # graph, so converting to NumPy is safe.
        cam = torch.clamp(cam, min=0).cpu().numpy()
        return cam, predicted_class, None  # No error
    except Exception as e:
        # Boundary handler: callers rely on the error being returned, not raised.
        error_message = f"Error generating Grad-CAM: {e}"
        print(error_message)
        return None, None, error_message
# Function to predict classes and visualize Grad-CAM
def predict_and_explain(image):
    """Classify an image and render a relevance heatmap for the prediction.

    Args:
        image: PIL image to classify, or ``None`` when nothing was supplied.

    Returns:
        A dict with three keys:
        ``"predicted class"`` -- the label of the arg-max class, or an error
        description;
        ``"Grad-CAM map"`` -- a PIL image of the rendered heatmap, or ``None``;
        ``"error log"`` -- ``"No errors"`` on success, otherwise the error text.
        The function never raises; failures are printed and reported in the
        returned dict.
    """
    import io  # local import: only needed for the in-memory PNG buffer

    # Report a missing input explicitly rather than via an obscure
    # AttributeError on ``None``.
    if image is None:
        error_message = "Error predicting and explaining: no image provided"
        print(error_message)
        return {
            "predicted class": "Error during prediction",
            "Grad-CAM map": None,
            "error log": error_message
        }

    try:
        # Convert image to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess the image
        inputs = image_processor(images=image, return_tensors="pt")

        # Forward pass to get logits; no gradients are needed for the
        # prediction itself.
        input_tensor = inputs["pixel_values"]
        with torch.no_grad():
            outputs = model(input_tensor)
        logits = outputs.logits
        predicted_class = logits.argmax(-1).item()

        cam_map, _, grad_cam_error = generate_grad_cam(image, "pooler_output")

        # Check for Grad-CAM errors
        if grad_cam_error is not None:
            return {
                "predicted class": "Error during Grad-CAM generation",
                "Grad-CAM map": None,
                "error log": grad_cam_error
            }

        # Map the class index to a name. The model may emit indices outside
        # the local label list; fall back to the model's own label table (or
        # a generic name) instead of failing with IndexError.
        if 0 <= predicted_class < len(labels):
            class_label = labels[predicted_class]
        else:
            id2label = getattr(model.config, "id2label", None) or {}
            class_label = id2label.get(predicted_class, f"class {predicted_class}")

        # Convert cam_map to a visualizable format (heatmap)
        if cam_map is not None:
            # Use a dedicated figure and always close it, so a failure while
            # plotting or saving cannot leak figures across requests.
            fig, ax = plt.subplots()
            try:
                heatmap = ax.imshow(cam_map, cmap='jet', alpha=0.5)
                ax.axis('off')
                ax.set_title(f"Grad-CAM for {class_label}")
                fig.colorbar(heatmap, ax=ax)
                # Render to memory rather than a fixed file name: a shared
                # file could be overwritten by another request before this
                # one has read it back.
                buffer = io.BytesIO()
                fig.savefig(buffer, format="png")
            finally:
                plt.close(fig)
            buffer.seek(0)
            grad_cam_image = Image.open(buffer)
            grad_cam_image.load()  # decode now; Image.open is lazy
        else:
            grad_cam_image = None

        return {
            "predicted class": class_label,
            "Grad-CAM map": grad_cam_image,
            "error log": "No errors"
        }
    except Exception as e:
        # Boundary handler: surface the failure in the result instead of raising.
        error_message = f"Error predicting and explaining: {e}"
        print(error_message)
        return {
            "predicted class": "Error during prediction",
            "Grad-CAM map": None,
            "error log": error_message
        }
# Use the updated Gradio components syntax
def _predict_for_interface(image):
    """Adapt ``predict_and_explain`` to the three positional outputs below.

    The interface declares three output components, so the handler must
    return three values in order. A dict result is unpacked into
    (predicted class, Grad-CAM map, error log); any other result is passed
    through unchanged.
    """
    result = predict_and_explain(image)
    if isinstance(result, dict):
        return (
            result.get("predicted class"),
            result.get("Grad-CAM map"),
            result.get("error log"),
        )
    return result


iface = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Image(type="pil"),  # Proper input type for images
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Image(label="Grad-CAM Map"),
        gr.Textbox(label="Error Log")
    ],
    title="Chest X-ray Classification with Debugging Logs"
)

if __name__ == "__main__":
    iface.launch()