File size: 2,370 Bytes
2208322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from torch.nn.functional import interpolate

class GradCAM:
    """Grad-CAM heat-map generator for one layer of a network.

    A forward hook on ``target_layer`` records the layer output and attaches a
    tensor hook to it, so the gradient of the class score with respect to that
    output is recorded during ``backward``. The hook stays registered until
    :meth:`remove_hooks` is called (or the instance is used as a context
    manager); release it when done so captured tensors do not pile up.

    Args:
        model: Network to explain; called as ``model(input_tensor)``.
        target_layer: Sub-module of ``model`` whose output builds the CAM.
            Its output is indexed as ``[sample, channel, h, w]`` below, so it
            is assumed to be a single 4-D tensor.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = []
        self.gradients = []

        # Only a forward hook is registered: it hooks the output tensor
        # directly, which yields exactly d(score)/d(output). This replaces the
        # deprecated Module.register_backward_hook, whose grad_output is
        # unreliable for modules made of several operations.
        self._handles = [
            target_layer.register_forward_hook(self.save_activations),
        ]

    def save_activations(self, module, input, output):
        """Forward hook: store the layer output and hook its gradient."""
        self.activations.append(output.detach())
        # Outside autograd (e.g. inference under torch.no_grad() while the
        # hook is still attached) the output has no graph and register_hook
        # would raise, so only hook when a gradient can actually flow.
        if output.requires_grad:
            output.register_hook(
                lambda grad: self.save_gradients(module, None, (grad,)))

    def save_gradients(self, module, grad_input, grad_output):
        """Store the gradient w.r.t. the layer output (``grad_output[0]``)."""
        self.gradients.append(grad_output[0].detach())

    def remove_hooks(self):
        """Detach the hooks from ``target_layer``; safe to call repeatedly."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()

    def forward(self, input_tensor):
        return self.model(input_tensor)

    def generate(self, input_tensor, target_class, size=None):
        """Compute the Grad-CAM map for the first sample of ``input_tensor``.

        Args:
            input_tensor: Batched model input; presumably ``(N, C, H, W)``.
            target_class: Index into the second dimension of the model output.
            size: ``(height, width)`` of the returned map. Defaults to the last
                two dimensions of ``input_tensor`` so the map aligns with it.

        Returns:
            2-D ``numpy.ndarray`` (float32) of shape ``size``, scaled so its
            maximum is 1 (left all zeros when nothing is positive).
        """
        if size is None:
            size = tuple(input_tensor.shape[-2:])

        # Drop tensors captured by earlier calls or unrelated forward passes;
        # otherwise index 0 below would keep returning the first call's CAM.
        self.activations.clear()
        self.gradients.clear()

        output = self.forward(input_tensor)

        self.model.zero_grad()
        # Mean over the batch gives a scalar to backprop; only sample 0 of
        # the captured tensors is used below.
        score = output[:, target_class].mean()
        score.backward()

        activations = self.activations[0][0]  # (channels, h, w), sample 0
        gradients = self.gradients[0][0]

        # Channel weights are the spatially averaged gradients; the CAM is
        # the ReLU of the weighted channel sum.
        weights = gradients.mean(dim=(1, 2))
        cam = (weights[:, None, None] * activations).sum(dim=0).clamp(min=0)

        cam = interpolate(cam[None, None].float(), size=size,
                          mode='bilinear', align_corners=False)
        # Index instead of squeeze so a 1-pixel dimension is not dropped.
        cam = cam[0, 0].cpu().numpy()

        peak = cam.max()
        if peak != 0:
            cam /= peak
        return cam


def generate_gradcam(image, target_class, model, target_layer, device=None):
    """Return the Grad-CAM heat map of ``model`` for a single image.

    Args:
        image: A ``(C, H, W)`` tensor, or anything ``transforms.ToTensor``
            accepts (e.g. a PIL image), which is converted first.
        target_class: Class index to explain.
        model: Network to explain.
        target_layer: Sub-module of ``model`` to build the CAM from.
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model).

    Returns:
        2-D ``numpy.ndarray`` with the image's height and width, max-scaled
        to 1.
    """
    if not isinstance(image, torch.Tensor):
        image = transforms.ToTensor()(image)

    if device is None:
        # Run where the model lives instead of relying on a global ``device``
        # name, which this module never defines.
        first_param = next(model.parameters(), None)
        device = (first_param.device if first_param is not None
                  else torch.device('cpu'))

    # Add the batch dim; requiring grad on the input keeps an autograd graph
    # even when the model's parameters are frozen.
    batch = image.unsqueeze(0).requires_grad_(True).to(device)

    # The context manager removes the hook afterwards; without it every call
    # would leave another hook (and its captured tensors) on target_layer.
    with GradCAM(model, target_layer) as gradcam:
        return gradcam.generate(batch, target_class)