import torch import matplotlib.pyplot as plt import numpy as np from torchvision import transforms from torch.nn.functional import interpolate class GradCAM: def __init__(self, model, target_layer): self.model = model self.target_layer = target_layer self.activations = [] self.gradients = [] # Register hooks target_layer.register_forward_hook(self.save_activations) target_layer.register_backward_hook(self.save_gradients) def save_activations(self, module, input, output): self.activations.append(output.detach()) def save_gradients(self, module, grad_input, grad_output): self.gradients.append(grad_output[0].detach()) def forward(self, input_tensor): return self.model(input_tensor) def generate(self, input_tensor, target_class): # Forward pass output = self.forward(input_tensor) # Backward pass for specific class self.model.zero_grad() loss = output[:, target_class].mean() loss.backward(retain_graph=True) # Get activations and gradients activations = self.activations[0].cpu().data.numpy()[0] gradients = self.gradients[0].cpu().data.numpy()[0] # Compute weights weights = np.mean(gradients, axis=(1, 2)) # Create CAM cam = np.zeros(activations.shape[1:], dtype=np.float32) for i, w in enumerate(weights): cam += w * activations[i, :, :] # Post-process CAM cam = np.maximum(cam, 0) cam = interpolate(torch.from_numpy(cam[None, None]), size=(224, 224), mode='bilinear').numpy() cam = cam.squeeze() if cam.max() != 0: cam /= cam.max() return cam def generate_gradcam(image, target_class, model, target_layer): # Preprocess image preprocess = transforms.Compose([ transforms.ToTensor(), ]) if not isinstance(image, torch.Tensor): image = preprocess(image) image_preprocessed = image.unsqueeze(0).requires_grad_(True).to(device) # Initialize Grad-CAM gradcam = GradCAM(model, target_layer) # Generate CAM image = image.to(device) cam = gradcam.generate(image_preprocessed, target_class) return cam