danielle2003 commited on
Commit
ebe3ca3
·
verified ·
1 Parent(s): d600b3f

Upload gradcam.py

Browse files
Files changed (1) hide show
  1. gradcam.py +77 -0
gradcam.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from torchvision import transforms
5
+ from torch.nn.functional import interpolate
6
+
7
+ class GradCAM:
8
+ def __init__(self, model, target_layer):
9
+ self.model = model
10
+ self.target_layer = target_layer
11
+ self.activations = []
12
+ self.gradients = []
13
+
14
+ # Register hooks
15
+ target_layer.register_forward_hook(self.save_activations)
16
+ target_layer.register_backward_hook(self.save_gradients)
17
+
18
+ def save_activations(self, module, input, output):
19
+ self.activations.append(output.detach())
20
+
21
+ def save_gradients(self, module, grad_input, grad_output):
22
+ self.gradients.append(grad_output[0].detach())
23
+
24
+ def forward(self, input_tensor):
25
+ return self.model(input_tensor)
26
+
27
+ def generate(self, input_tensor, target_class):
28
+ # Forward pass
29
+ output = self.forward(input_tensor)
30
+
31
+ # Backward pass for specific class
32
+ self.model.zero_grad()
33
+ loss = output[:, target_class].mean()
34
+ loss.backward(retain_graph=True)
35
+
36
+ # Get activations and gradients
37
+ activations = self.activations[0].cpu().data.numpy()[0]
38
+ gradients = self.gradients[0].cpu().data.numpy()[0]
39
+
40
+ # Compute weights
41
+ weights = np.mean(gradients, axis=(1, 2))
42
+
43
+ # Create CAM
44
+ cam = np.zeros(activations.shape[1:], dtype=np.float32)
45
+ for i, w in enumerate(weights):
46
+ cam += w * activations[i, :, :]
47
+
48
+ # Post-process CAM
49
+ cam = np.maximum(cam, 0)
50
+ cam = interpolate(torch.from_numpy(cam[None, None]),
51
+ size=(224, 224), mode='bilinear').numpy()
52
+ cam = cam.squeeze()
53
+ if cam.max() != 0:
54
+ cam /= cam.max()
55
+
56
+ return cam
57
+
58
+ def generate_gradcam(image, target_class, model, target_layer):
59
+ # Preprocess image
60
+ preprocess = transforms.Compose([
61
+ transforms.ToTensor(),
62
+ ])
63
+
64
+ if not isinstance(image, torch.Tensor):
65
+ image = preprocess(image)
66
+
67
+ image_preprocessed = image.unsqueeze(0).requires_grad_(True).to(device)
68
+
69
+ # Initialize Grad-CAM
70
+ gradcam = GradCAM(model, target_layer)
71
+
72
+ # Generate CAM
73
+ image = image.to(device)
74
+
75
+ cam = gradcam.generate(image_preprocessed, target_class)
76
+ return cam
77
+