# Source page header (not code): "Spaces: Sleeping"
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn.functional import interpolate
from torchvision import transforms
class GradCAM:
    """Grad-CAM heat-map generator bound to one layer of a PyTorch model.

    On construction a forward hook is registered on ``target_layer``. Every
    forward pass records the layer's output (the "activations") and attaches
    a tensor hook to that output so the gradient flowing back into it is
    recorded too. Call :meth:`remove_hooks` (or use the instance as a context
    manager) when finished; otherwise the hook stays attached to the model.

    Args:
        model: the network; called as ``model(input_tensor)`` and must support
            ``zero_grad()``.
        target_layer: submodule of ``model`` whose output map is explained.
            Its output is assumed to be a 4-D (N, C, H, W) tensor.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        # Filled by the hooks; cleared at the start of every ``generate`` call.
        self.activations = []
        self.gradients = []
        # Gradients are captured through a hook on the layer's output tensor
        # (see ``save_activations``) instead of a module backward hook:
        # ``register_backward_hook`` is deprecated, and
        # ``register_full_backward_hook`` raises if a later layer modifies the
        # output in place.
        self._handles = [target_layer.register_forward_hook(self.save_activations)]

    def save_activations(self, module, input, output):
        """Forward hook: store the layer output and hook its gradient."""
        # clone(): ``detach()`` shares storage, so a following in-place op
        # would otherwise overwrite the stored activations.
        self.activations.append(output.detach().clone())
        if output.requires_grad:
            output.register_hook(self._save_output_grad)

    def save_gradients(self, module, grad_input, grad_output):
        """Store ``grad_output[0]`` (module backward-hook signature).

        Not registered by this class; kept for callers that attach it as a
        module backward hook themselves.
        """
        self.gradients.append(grad_output[0].detach())

    def _save_output_grad(self, grad):
        # Tensor hook: ``grad`` is d(score)/d(layer output).
        self.gradients.append(grad.detach())

    def remove_hooks(self):
        """Detach every hook this instance registered on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.remove_hooks()
        return False

    def forward(self, input_tensor):
        """Run the wrapped model on ``input_tensor``."""
        return self.model(input_tensor)

    def generate(self, input_tensor, target_class, output_size=(224, 224)):
        """Compute the Grad-CAM map of ``target_class`` for ``input_tensor``.

        Args:
            input_tensor: batched model input. Only the first sample's map is
                returned, but the backpropagated score is averaged over the
                whole batch.
            target_class: index into the second dimension of the model output.
            output_size: (height, width) the map is bilinearly resized to;
                ``None`` uses the last two dimensions of ``input_tensor``.

        Returns:
            float32 numpy array of shape ``output_size`` with values in
            [0, 1] (all zeros when no location contributes positively).

        Raises:
            RuntimeError: if no activations or gradients were captured.
        """
        # Forget results of any previous call so a reused instance never
        # returns a stale map.
        self.activations.clear()
        self.gradients.clear()

        # Gradients are needed even when the caller runs under torch.no_grad().
        with torch.enable_grad():
            output = self.forward(input_tensor)
            self.model.zero_grad()
            score = output[:, target_class].mean()
            score.backward()

        if not self.activations or not self.gradients:
            raise RuntimeError(
                "GradCAM captured no activations/gradients; make sure "
                "target_layer is used in the forward pass and its output "
                "requires grad."
            )

        # For a layer applied once per forward (the usual case) each list holds
        # exactly one tensor; if it was applied several times, backward visits
        # the last use first, so pair the last activation with the first grad.
        activations = self.activations[-1][0].cpu().numpy()  # (C, H, W)
        gradients = self.gradients[0][0].cpu().numpy()  # (C, H, W)

        # Channel weights = spatially averaged gradients.
        weights = gradients.mean(axis=(1, 2))
        # Weighted sum of channel maps -> (H, W), then keep positive evidence.
        cam = np.tensordot(weights, activations, axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)

        if output_size is None:
            output_size = tuple(input_tensor.shape[-2:])
        cam = interpolate(
            torch.from_numpy(cam)[None, None],
            size=output_size,
            mode="bilinear",
            align_corners=False,
        )[0, 0].numpy()

        peak = cam.max()
        if peak > 0:
            cam /= peak
        return cam


def generate_gradcam(image, target_class, model, target_layer, device=None):
    """Compute a Grad-CAM heat map for a single image.

    Args:
        image: a ``torch.Tensor`` shaped (C, H, W) or already batched
            (N, C, H, W), used as-is; anything else is converted with
            ``torchvision.transforms.ToTensor`` (no resizing/normalization).
        target_class: output index to explain.
        model: the network (an ``nn.Module``).
        target_layer: submodule of ``model`` whose output is explained.
        device: device to run on; defaults to the device of the model's first
            parameter, or CPU if the model has no parameters.

    Returns:
        The numpy heat map from ``GradCAM.generate`` (224x224, values in [0, 1]).
    """
    if not isinstance(image, torch.Tensor):
        image = transforms.Compose([transforms.ToTensor()])(image)

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    batch = image if image.dim() == 4 else image.unsqueeze(0)
    # Move first, then mark as a leaf that requires grad.
    batch = batch.detach().to(device).requires_grad_(True)

    # The context manager removes the hooks so repeated calls don't pile up
    # hooks on the model.
    with GradCAM(model, target_layer) as gradcam:
        return gradcam.generate(batch, target_class)