saakshigupta committed
Commit e5c0da6 · verified · 1 Parent(s): 43ab796

Upload gradcam_clip_large-2.py

Files changed (1):
  1. gradcam_clip_large-2.py +345 -0
gradcam_clip_large-2.py ADDED
@@ -0,0 +1,345 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import warnings
from transformers import AutoProcessor, CLIPModel
import cv2
import re
from huggingface_hub import hf_hub_download
import io
warnings.filterwarnings("ignore", category=UserWarning)

class ImageDataset(Dataset):
    def __init__(self, image, transform=None, face_only=True, dataset_name=None):
        # Modified to accept a single PIL image instead of a list of paths
        self.image = image
        self.transform = transform
        self.face_only = face_only
        self.dataset_name = dataset_name
        # Load face detector
        self.face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    def __len__(self):
        return 1  # Only one image

    def detect_face(self, image_np):
        """Detect face in image and return the face region"""
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        faces = self.face_detector.detectMultiScale(gray, 1.1, 5)

        # If no face is detected, use the whole image
        if len(faces) == 0:
            print("No face detected, using whole image")
            h, w = image_np.shape[:2]
            return (0, 0, w, h), image_np

        # Get the largest face
        if len(faces) > 1:
            # Choose the largest face by area
            areas = [w * h for (x, y, w, h) in faces]
            largest_idx = np.argmax(areas)
            x, y, w, h = faces[largest_idx]
        else:
            x, y, w, h = faces[0]

        # Add padding around the face (5% on each side)
        padding_x = int(w * 0.05)
        padding_y = int(h * 0.05)

        # Ensure padding doesn't go outside image bounds
        x1 = max(0, x - padding_x)
        y1 = max(0, y - padding_y)
        x2 = min(image_np.shape[1], x + w + padding_x)
        y2 = min(image_np.shape[0], y + h + padding_y)

        # Extract the face region
        face_img = image_np[y1:y2, x1:x2]

        return (x1, y1, x2 - x1, y2 - y1), face_img

    def __getitem__(self, idx):
        # Use the single image provided
        image_np = np.array(self.image)
        label = 0  # Default label; will be overridden by prediction in app.py

        # Store original image for visualization
        original_image = self.image.copy()

        # Detect face if required
        if self.face_only:
            face_box, face_img_np = self.detect_face(image_np)
            face_img = Image.fromarray(face_img_np)

            # Apply transform to face image
            if self.transform:
                face_tensor = self.transform(face_img)
            else:
                face_tensor = transforms.ToTensor()(face_img)

            return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
        else:
            # Process the whole image
            if self.transform:
                image_tensor = self.transform(self.image)
            else:
                image_tensor = transforms.ToTensor()(self.image)

            return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            if isinstance(output, tuple):
                self.activations = output[0]
            else:
                self.activations = output

        def backward_hook(module, grad_in, grad_out):
            if isinstance(grad_out, tuple):
                self.gradients = grad_out[0]
            else:
                self.gradients = grad_out

        layer = dict([*self.model.named_modules()])[self.target_layer]
        layer.register_forward_hook(forward_hook)
        layer.register_backward_hook(backward_hook)

    def generate(self, input_tensor, class_idx):
        self.model.zero_grad()

        try:
            # Use only the vision part of the model for gradient calculation
            vision_outputs = self.model.vision_model(pixel_values=input_tensor)

            # Get the pooler output
            features = vision_outputs.pooler_output

            # Create a one-hot gradient for the feature based on the class idx
            one_hot = torch.zeros_like(features)
            one_hot[0, class_idx] = 1

            # Manually backpropagate
            features.backward(gradient=one_hot)

            # Check for None values
            if self.gradients is None or self.activations is None:
                print("Warning: Gradients or activations are None. Using fallback CAM.")
                return np.ones((14, 14), dtype=np.float32) * 0.5

            # Process gradients and activations
            if len(self.gradients.shape) == 4:  # Expected shape for convolutional layers
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()

                # Standard Grad-CAM: weight each channel by its average gradient
                weights = np.mean(gradients, axis=(2, 3))
                cam = np.zeros(activations.shape[2:], dtype=np.float32)

                for i, w in enumerate(weights[0]):
                    cam += w * activations[0, i, :, :]
            else:
                # Handle transformer (ViT) activation format
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()

                if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
                    seq_len = activations.shape[1]

                    # A ViT sequence is 1 class token + N*N patch tokens
                    # (257 = 16x16 + 1 for CLIP ViT-L/14 at 224px, 197 = 14x14 + 1 for ViT-B/16)
                    grid_size = int(np.sqrt(seq_len - 1))
                    if grid_size * grid_size == seq_len - 1:
                        # Skip the class token (first token) and reshape the patch tokens into a square grid;
                        # mean absolute activation is used as the per-token importance here
                        patch_tokens = activations[0, 1:, :]
                        token_importance = np.mean(np.abs(patch_tokens), axis=1)
                        cam = token_importance.reshape(grid_size, grid_size)
                    else:
                        # Otherwise fit the tokens into the closest square grid
                        side_len = int(np.sqrt(seq_len))
                        token_importance = np.mean(np.abs(activations[0]), axis=1)
                        flat_cam = np.zeros(side_len * side_len, dtype=np.float32)
                        n = min(len(token_importance), len(flat_cam))
                        flat_cam[:n] = token_importance[:n]
                        cam = flat_cam.reshape(side_len, side_len)
                else:
                    # Fallback
                    print("Using fallback CAM shape (14x14)")
                    cam = np.ones((14, 14), dtype=np.float32) * 0.5

            # Ensure we have valid values
            if cam is None or cam.size == 0:
                print("Warning: Generated CAM is empty. Using fallback.")
                cam = np.ones((14, 14), dtype=np.float32) * 0.5

            # ReLU and normalize to [0, 1]
            cam = np.maximum(cam, 0)
            if np.max(cam) > 0:
                cam = cam / np.max(cam)

            return cam

        except Exception as e:
            print(f"Error in GradCAM.generate: {str(e)}")
            return np.ones((14, 14), dtype=np.float32) * 0.5

def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
    # Image.blend requires both images to share the same mode and size
    image = image.convert("RGB")

    if face_box is not None:
        x, y, w, h = face_box
        # Create a mask for the entire image (all zeros initially)
        img_np = np.array(image)
        full_h, full_w = img_np.shape[:2]
        full_cam = np.zeros((full_h, full_w), dtype=np.float32)

        # Resize CAM to match the face region
        face_cam = cv2.resize(cam, (w, h))

        # Copy the face CAM into the full-image CAM at the face position
        full_cam[y:y + h, x:x + w] = face_cam

        # Convert the full CAM to a colormapped image
        cam_resized = Image.fromarray((full_cam * 255).astype(np.uint8))
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
        cam_colormap = (cam_colormap * 255).astype(np.uint8)
    else:
        cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size, Image.BILINEAR)
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
        cam_colormap = (cam_colormap * 255).astype(np.uint8)

    blended = Image.blend(image, Image.fromarray(cam_colormap), alpha=alpha)
    return blended

def save_comparison(image, cam, overlay, face_box=None):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axes[0].imshow(image)
    axes[0].set_title("Original")
    if face_box is not None:
        x, y, w, h = face_box
        rect = plt.Rectangle((x, y), w, h, edgecolor='lime', linewidth=2, fill=False)
        axes[0].add_patch(rect)
    axes[0].axis("off")

    # CAM
    if face_box is not None:
        # Create a full-image CAM that highlights only the face
        img_np = np.array(image)
        h, w = img_np.shape[:2]
        full_cam = np.zeros((h, w))

        x, y, fw, fh = face_box
        # Resize CAM to face size and place it in the right position
        face_cam = cv2.resize(cam, (fw, fh))
        full_cam[y:y + fh, x:x + fw] = face_cam
        axes[1].imshow(full_cam, cmap="jet")
    else:
        axes[1].imshow(cam, cmap="jet")
    axes[1].set_title("CAM")
    axes[1].axis("off")

    # Overlay
    axes[2].imshow(overlay)
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    plt.tight_layout()

    # Convert plot to PIL Image for Streamlit display
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def load_clip_model():
    # Modified to load the fine-tuned checkpoint from the Hugging Face Hub
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    checkpoint_path = hf_hub_download(repo_id="drg31/model", filename="model.pth")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Keep only checkpoint weights whose name and shape match the base model
    model_dict = model.state_dict()
    checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict and model_dict[k].shape == v.shape}

    model_dict.update(checkpoint)
    model.load_state_dict(model_dict)

    model.eval()
    return model, processor

def get_target_layer_clip(model):
    # For CLIP ViT-Large, use a layer whose activations arrive in the expected [batch, seq_len, hidden_dim] format
    return "vision_model.encoder.layers.23"

def process_images(dataloader, model, cam_extractor, device, pred_class):
    # Modified to process a single image and return results for Streamlit
    for batch in dataloader:
        input_tensor, label, img_paths, original_images, face_boxes, dataset_names = batch
        original_image = original_images[0]
        face_box = face_boxes[0]

        print("Processing uploaded image...")

        # Move tensors and model to device
        input_tensor = input_tensor.to(device)
        model = model.to(device)

        try:
            # Forward pass and Grad-CAM generation
            output = model.vision_model(pixel_values=input_tensor).pooler_output
            class_idx = pred_class  # Use predicted class from app.py
            cam = cam_extractor.generate(input_tensor, class_idx)

            # Generate CAM image
            if face_box is not None:
                x, y, w, h = face_box
                img_np = np.array(original_image)
                h_full, w_full = img_np.shape[:2]
                full_cam = np.zeros((h_full, w_full))
                face_cam = cv2.resize(cam, (w, h))
                full_cam[y:y + h, x:x + w] = face_cam
                cam_img = Image.fromarray((plt.cm.jet(full_cam)[:, :, :3] * 255).astype(np.uint8))
            else:
                cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
                cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
                cam_colormap = (cam_colormap * 255).astype(np.uint8)
                cam_img = Image.fromarray(cam_colormap)

            # Generate overlay
            overlay = overlay_cam_on_image(original_image, cam, face_box)

            # Generate comparison figure
            comparison = save_comparison(original_image, cam, overlay, face_box)

            return cam, cam_img, overlay, comparison

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()
            # Return default values in case of error
            default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
            cam_resized = Image.fromarray((default_cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
            cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
            cam_colormap = (cam_colormap * 255).astype(np.uint8)
            cam_img = Image.fromarray(cam_colormap)
            overlay = overlay_cam_on_image(original_image, default_cam, face_box)
            comparison = save_comparison(original_image, default_cam, overlay, face_box)
            return default_cam, cam_img, overlay, comparison
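
The module is designed to be driven by app.py, which supplies the predicted class and displays the returned images in Streamlit. For reference, below is a minimal usage sketch of how the pieces fit together as a standalone `__main__` guard appended to this file. The `collate_single` helper, the CLIP preprocessing transform, the example image path, and the hard-coded `pred_class` are illustrative assumptions, not part of the uploaded file.

if __name__ == "__main__":
    # Minimal usage sketch; collate_single, the transform, and pred_class are illustrative assumptions
    def collate_single(batch):
        # Keep PIL images and face boxes as plain Python objects; only the input tensor is batched
        tensor, label, name, original_image, face_box, dataset_name = batch[0]
        return tensor.unsqueeze(0), [label], [name], [original_image], [face_box], [dataset_name]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, processor = load_clip_model()
    target_layer = get_target_layer_clip(model)
    cam_extractor = GradCAM(model, target_layer)

    # CLIP-style preprocessing (224x224 input with CLIP normalization constants)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    image = Image.open("example.jpg").convert("RGB")  # hypothetical input path
    dataset = ImageDataset(image, transform=transform, face_only=True)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_single)

    pred_class = 0  # placeholder; app.py normally supplies the classifier's predicted class
    cam, cam_img, overlay, comparison = process_images(dataloader, model, cam_extractor, device, pred_class)
    overlay.save("overlay.png")

The custom collate function is needed because the dataset yields PIL images and plain tuples, which PyTorch's default collation cannot batch; passing them through as single-element lists matches what process_images expects when it indexes `original_images[0]` and `face_boxes[0]`.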