Spaces:
Sleeping
Sleeping
File size: 2,370 Bytes
2208322 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from torch.nn.functional import interpolate
class GradCAM:
    """Grad-CAM heat maps for a single layer of a PyTorch model.

    A forward hook on ``target_layer`` records the layer's output on every
    forward pass and attaches a tensor hook that records the gradient of the
    backpropagated loss with respect to that output. :meth:`generate` turns
    that activation/gradient pair into a heat map.

    The forward hook stays registered until :meth:`remove_hooks` is called;
    using the instance as a context manager (``with GradCAM(...) as cam:``)
    removes it automatically on exit.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        # Per-forward-pass captures; generate() clears them before each run.
        self.activations = []
        self.gradients = []
        # Gradients are captured with Tensor.register_hook from inside the
        # forward hook, instead of the deprecated Module.register_backward_hook.
        self._handles = [
            target_layer.register_forward_hook(self.save_activations)
        ]

    def save_activations(self, module, input, output):
        """Forward hook: store ``output`` and hook its gradient.

        The stored copy is cloned because ``detach()`` shares storage, so a
        later in-place op (e.g. ``ReLU(inplace=True)``) would otherwise
        rewrite the saved activations.
        """
        self.activations.append(output.detach().clone())
        # Without a graph no gradient will ever arrive (e.g. under
        # torch.no_grad()), and register_hook would raise.
        if output.requires_grad:
            output.register_hook(self._save_output_grad)

    def _save_output_grad(self, grad):
        """Tensor hook: store the gradient w.r.t. the target layer output."""
        self.gradients.append(grad.detach())

    def save_gradients(self, module, grad_input, grad_output):
        """Module-backward-hook callback storing the output gradient.

        No longer registered by ``__init__``; kept so callers that register
        it themselves keep working.
        """
        self.gradients.append(grad_output[0].detach())

    def remove_hooks(self):
        """Remove the hook(s) this instance registered; safe to call twice."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()
        return False

    def forward(self, input_tensor):
        """Run the wrapped model; the hooks capture the target layer."""
        return self.model(input_tensor)

    def generate(self, input_tensor, target_class, output_size=None):
        """Compute a Grad-CAM map for ``target_class`` on the first sample.

        Args:
            input_tensor: batched model input; its last two dimensions are
                the default output size.
            target_class: index into the second dimension of the model output.
            output_size: ``(height, width)`` to upsample the map to; defaults
                to ``input_tensor.shape[-2:]``.

        Returns:
            2-D float32 numpy array, clipped at 0 and divided by its maximum
            (so within [0, 1] unless it is all zeros).

        Raises:
            RuntimeError: if the hooks captured no activation or gradient.
        """
        # Drop captures from earlier calls so stale maps are never used.
        self.activations.clear()
        self.gradients.clear()

        output = self.forward(input_tensor)
        self.model.zero_grad()
        loss = output[:, target_class].mean()
        loss.backward()

        if not self.activations or not self.gradients:
            raise RuntimeError(
                "GradCAM captured no activations/gradients: target_layer must "
                "run in model.forward with autograd enabled."
            )

        # NOTE(review): assumes target_layer runs once per forward pass and
        # that its per-sample output is (channels, H, W) -- confirm per model.
        activations = self.activations[0][0].float().cpu().numpy()
        gradients = self.gradients[0][0].float().cpu().numpy()

        # Channel weights: spatial mean of the gradient for each channel.
        weights = gradients.mean(axis=(1, 2))
        # Weighted sum over channels -> (H, W) map, keeping positive evidence.
        cam = np.maximum(np.tensordot(weights, activations, axes=1), 0)

        if output_size is None:
            output_size = tuple(input_tensor.shape[-2:])
        cam = interpolate(
            torch.from_numpy(cam)[None, None],
            size=output_size,
            mode="bilinear",
            align_corners=False,
        )[0, 0].numpy()
        if cam.max() != 0:
            cam /= cam.max()
        return cam


def generate_gradcam(image, target_class, model, target_layer, device=None):
    """Convenience wrapper: Grad-CAM map for one image.

    Args:
        image: a tensor (unbatched ``(C, H, W)`` or already batched) or any
            input accepted by ``torchvision.transforms.ToTensor``.
        target_class: class index to explain.
        model: the network to run.
        target_layer: sub-module of ``model`` whose output is explained.
        device: where to run; defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        The heat map from :meth:`GradCAM.generate`.
    """
    if not isinstance(image, torch.Tensor):
        image = transforms.Compose([transforms.ToTensor()])(image)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    batch = image.unsqueeze(0) if image.dim() < 4 else image
    # detach() gives a fresh tensor object, so enabling grad here never
    # mutates the caller's tensor; input grad keeps the graph alive even
    # when the model's parameters are frozen.
    batch = batch.detach().to(device).requires_grad_(True)

    # The context manager removes the hook so repeated calls don't pile up
    # hooks on target_layer.
    with GradCAM(model, target_layer) as gradcam:
        return gradcam.generate(batch, target_class)
|