import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import warnings
from transformers import AutoProcessor, CLIPModel
import cv2
import re
from huggingface_hub import hf_hub_download
import io
warnings.filterwarnings("ignore", category=UserWarning)
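
# This module supplies the Grad-CAM explainability pieces used by app.py:
#  - ImageDataset wraps a single uploaded image (optionally cropped to the largest detected face)
#  - GradCAM hooks a layer of the fine-tuned CLIP vision tower and turns gradients/activations into a CAM
#  - overlay_cam_on_image / save_comparison / process_images render the CAM as a heatmap,
#    an overlay, and a three-panel comparison figure for display in Streamlit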

class ImageDataset(Dataset):
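    """One-item Dataset around a single PIL image.

    When face_only is True, the largest face found by an OpenCV Haar cascade is
    cropped (with 5% padding) and transformed; otherwise the whole image is used.
    """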
    def __init__(self, image, transform=None, face_only=True, dataset_name=None):
        # Modified to accept a single PIL image instead of a list of paths.
        # Convert to RGB so face detection (3-channel input) and Image.blend behave consistently.
        self.image = image.convert("RGB")
        self.transform = transform
        self.face_only = face_only
        self.dataset_name = dataset_name
        # Load face detector
        self.face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        
    def __len__(self):
        return 1  # Only one image

    def detect_face(self, image_np):
        """Detect face in image and return the face region"""
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        faces = self.face_detector.detectMultiScale(gray, 1.1, 5)
        
        # If no face is detected, use the whole image
        if len(faces) == 0:
            print("No face detected, using whole image")
            h, w = image_np.shape[:2]
            return (0, 0, w, h), image_np
            
        # Get the largest face
        if len(faces) > 1:
            # Choose the largest face by area
            areas = [w*h for (x, y, w, h) in faces]
            largest_idx = np.argmax(areas)
            x, y, w, h = faces[largest_idx]
        else:
            x, y, w, h = faces[0]
            
        # Add padding around the face (5% on each side - reduced padding)
        padding_x = int(w * 0.05)
        padding_y = int(h * 0.05)
        
        # Ensure padding doesn't go outside image bounds
        x1 = max(0, x - padding_x)
        y1 = max(0, y - padding_y)
        x2 = min(image_np.shape[1], x + w + padding_x)
        y2 = min(image_np.shape[0], y + h + padding_y)
        
        # Extract the face region
        face_img = image_np[y1:y2, x1:x2]
        
        return (x1, y1, x2-x1, y2-y1), face_img

    def __getitem__(self, idx):
        # Use the single image provided
        image_np = np.array(self.image)
        label = 0  # Default label; will be overridden by prediction in app.py
        
        # Store original image for visualization
        original_image = self.image.copy()
        
        # Detect face if required
        if self.face_only:
            face_box, face_img_np = self.detect_face(image_np)
            face_img = Image.fromarray(face_img_np)
            
            # Apply transform to face image
            if self.transform:
                face_tensor = self.transform(face_img)
            else:
                face_tensor = transforms.ToTensor()(face_img)
            
            return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
        else:
            # Process the whole image
            if self.transform:
                image_tensor = self.transform(self.image)
            else:
                image_tensor = transforms.ToTensor()(self.image)
                
            return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name

class GradCAM:
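    """Grad-CAM extractor for the CLIP vision tower.

    Registers hooks on `target_layer` to capture activations and gradients, then
    backpropagates from the vision model's pooler output to build a class
    activation map (falling back to a uniform 14x14 map on failure).
    """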
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            if isinstance(output, tuple):
                self.activations = output[0]
            else:
                self.activations = output

        def backward_hook(module, grad_in, grad_out):
            if isinstance(grad_out, tuple):
                self.gradients = grad_out[0]
            else:
                self.gradients = grad_out

        layer = dict([*self.model.named_modules()])[self.target_layer]
        layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated and unreliable on composite modules;
        # use the full backward hook, which reports gradients w.r.t. the layer output.
        layer.register_full_backward_hook(backward_hook)

    def generate(self, input_tensor, class_idx):
        self.model.zero_grad()
        
        try:
            # Use only the vision part of the model for gradient calculation
            vision_outputs = self.model.vision_model(pixel_values=input_tensor)
            
            # Get the pooler output
            features = vision_outputs.pooler_output
            
            # Create a dummy gradient for the feature based on the class idx
            one_hot = torch.zeros_like(features)
            one_hot[0, class_idx] = 1
            
            # Manually backpropagate
            features.backward(gradient=one_hot)

            # Check for None values
            if self.gradients is None or self.activations is None:
                print("Warning: Gradients or activations are None. Using fallback CAM.")
                return np.ones((14, 14), dtype=np.float32) * 0.5

            # Process gradients and activations
            if len(self.gradients.shape) == 4:  # Expected shape for convolutional layers
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()
                
                weights = np.mean(gradients, axis=(2, 3))
                cam = np.zeros(activations.shape[2:], dtype=np.float32)
                
                for i, w in enumerate(weights[0]):
                    cam += w * activations[0, i, :, :]
            else:
                # Handle transformer model format
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()
                
                if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
                    seq_len = activations.shape[1]
                    
                    # ViT-B/16 at 224 px yields 196 patch tokens (14×14) + 1 class token = 197;
                    # the ViT-L/14 model loaded here yields 257 tokens (16×16 + CLS), which the
                    # general branch below handles.
                    if seq_len == 197:
                        # Skip the class token (first token) and reshape the patch tokens into a square
                        patch_tokens = activations[0, 1:, :]  # Remove the class token
                        # Take the mean across the hidden dimension
                        token_importance = np.mean(np.abs(patch_tokens), axis=1)
                        # Reshape to the expected grid size (14×14 for CLIP ViT-B/16)
                        cam = token_importance.reshape(14, 14)
                    else:
                        # General case (e.g. ViT-L/14: 256 patch tokens + 1 class token = 257)
                        # Use the mean absolute activation across features as token importance
                        token_importance = np.mean(np.abs(activations[0]), axis=1)
                        # Drop the class token when the remaining tokens form a perfect square
                        if int(np.sqrt(seq_len - 1)) ** 2 == seq_len - 1:
                            token_importance = token_importance[1:]
                        # Reshape into the closest square grid, padding/truncating if needed
                        side_len = int(np.sqrt(len(token_importance)))
                        flat_cam = np.zeros(side_len * side_len, dtype=np.float32)
                        n = min(len(token_importance), flat_cam.size)
                        flat_cam[:n] = token_importance[:n]
                        cam = flat_cam.reshape(side_len, side_len)
                else:
                    # Fallback
                    print("Using fallback CAM shape (14x14)")
                    cam = np.ones((14, 14), dtype=np.float32) * 0.5  # Default fallback

            # Ensure we have valid values
            if cam is None or cam.size == 0:
                print("Warning: Generated CAM is empty. Using fallback.")
                cam = np.ones((14, 14), dtype=np.float32) * 0.5
                
            cam = np.maximum(cam, 0)
            if np.max(cam) > 0:
                cam = cam / np.max(cam)
            
            return cam
            
        except Exception as e:
            print(f"Error in GradCAM.generate: {str(e)}")
            return np.ones((14, 14), dtype=np.float32) * 0.5

def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
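    """Blend a jet-coloured CAM heatmap onto `image`.

    If `face_box` (x, y, w, h) is given, the CAM is resized to the face region and
    the rest of the image stays unhighlighted; otherwise the CAM covers the whole image.
    """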
    if face_box is not None:
        x, y, w, h = face_box
        # Create a mask for the entire image (all zeros initially)
        img_np = np.array(image)
        full_h, full_w = img_np.shape[:2]
        full_cam = np.zeros((full_h, full_w), dtype=np.float32)
        
        # Resize CAM to match face region
        face_cam = cv2.resize(cam, (w, h))
        
        # Copy the face CAM into the full image CAM at the face position
        full_cam[y:y+h, x:x+w] = face_cam
        
        # Convert full CAM to image
        cam_resized = Image.fromarray((full_cam * 255).astype(np.uint8))
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]  # Apply colormap
        cam_colormap = (cam_colormap * 255).astype(np.uint8)
    else:
        cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size, Image.BILINEAR)
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]  # Apply colormap
        cam_colormap = (cam_colormap * 255).astype(np.uint8)

    blended = Image.blend(image, Image.fromarray(cam_colormap), alpha=alpha)
    return blended

def save_comparison(image, cam, overlay, face_box=None):
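    """Build a three-panel figure (original / CAM / overlay) and return it as a PIL image."""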
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original Image
    axes[0].imshow(image)
    axes[0].set_title("Original")
    if face_box is not None:
        x, y, w, h = face_box
        rect = plt.Rectangle((x, y), w, h, edgecolor='lime', linewidth=2, fill=False)
        axes[0].add_patch(rect)
    axes[0].axis("off")

    # CAM
    if face_box is not None:
        # Create a full image CAM that highlights only the face
        img_np = np.array(image)
        h, w = img_np.shape[:2]
        full_cam = np.zeros((h, w))
        
        x, y, fw, fh = face_box
        # Resize CAM to face size
        face_cam = cv2.resize(cam, (fw, fh))
        # Place it in the right position
        full_cam[y:y+fh, x:x+fw] = face_cam
        axes[1].imshow(full_cam, cmap="jet")
    else:
        axes[1].imshow(cam, cmap="jet")
    axes[1].set_title("CAM")
    axes[1].axis("off")

    # Overlay
    axes[2].imshow(overlay)
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    plt.tight_layout()
    
    # Convert plot to PIL Image for Streamlit display
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def load_clip_model():
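    """Load CLIP ViT-L/14 and merge in the matching weights from the drg31/model checkpoint."""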
    # Modified to load checkpoint from Hugging Face
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
    
    checkpoint_path = hf_hub_download(repo_id="drg31/model", filename="model.pth")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    model_dict = model.state_dict()
    checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict and model_dict[k].shape == v.shape}
    
    model_dict.update(checkpoint)
    model.load_state_dict(model_dict)
    
    model.eval()
    return model, processor

def get_target_layer_clip(model):
    # For CLIP ViT large, use a layer that will have activations in the right format
    return "vision_model.encoder.layers.23"

def process_images(dataloader, model, cam_extractor, device, pred_class):
    # Modified to process a single image and return results for Streamlit
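    # Returns (raw CAM array, CAM heatmap image, overlay image, comparison figure)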
    for batch in dataloader:
        input_tensor, label, img_paths, original_images, face_boxes, dataset_names = batch
        original_image = original_images[0]
        face_box = face_boxes[0]
        
        print(f"Processing uploaded image...")
        
        # Move tensors and model to device
        input_tensor = input_tensor.to(device)
        model = model.to(device)

        try:
            # Forward pass and Grad-CAM generation
            output = model.vision_model(pixel_values=input_tensor).pooler_output
            class_idx = pred_class  # Use predicted class from app.py
            cam = cam_extractor.generate(input_tensor, class_idx)

            # Generate CAM image
            if face_box is not None:
                x, y, w, h = face_box
                img_np = np.array(original_image)
                h_full, w_full = img_np.shape[:2]
                full_cam = np.zeros((h_full, w_full))
                face_cam = cv2.resize(cam, (w, h))
                full_cam[y:y+h, x:x+w] = face_cam
                cam_img = Image.fromarray((plt.cm.jet(full_cam)[:, :, :3] * 255).astype(np.uint8))
            else:
                cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
                cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
                cam_colormap = (cam_colormap * 255).astype(np.uint8)
                cam_img = Image.fromarray(cam_colormap)

            # Generate Overlay
            overlay = overlay_cam_on_image(original_image, cam, face_box)

            # Generate Comparison
            comparison = save_comparison(original_image, cam, overlay, face_box)

            return cam, cam_img, overlay, comparison
            
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()
            # Return default values in case of error
            default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
            cam_resized = Image.fromarray((default_cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
            cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
            cam_colormap = (cam_colormap * 255).astype(np.uint8)
            cam_img = Image.fromarray(cam_colormap)
            overlay = overlay_cam_on_image(original_image, default_cam, face_box)
            comparison = save_comparison(original_image, default_cam, overlay, face_box)
            return default_cam, cam_img, overlay, comparison
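

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The real wiring lives in app.py,
# which supplies the prediction used as pred_class; the file name, transform,
# and collate_fn below are assumptions chosen to match what ImageDataset and
# process_images expect, not part of the original pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, processor = load_clip_model()
    cam_extractor = GradCAM(model, get_target_layer_clip(model))

    # Assumed preprocessing: CLIP ViT-L/14 expects 224x224 inputs with CLIP's
    # standard normalisation statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    image = Image.open("example.jpg").convert("RGB")  # hypothetical input file
    dataset = ImageDataset(image, transform=transform, face_only=True)

    # Default collation cannot batch PIL images or None face boxes, so keep the
    # non-tensor fields as plain Python lists and only stack the image tensor.
    def collate_single(batch):
        tensors, labels, paths, originals, boxes, names = zip(*batch)
        return (torch.stack(tensors), list(labels), list(paths),
                list(originals), list(boxes), list(names))

    dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_single)

    # pred_class normally comes from the classifier in app.py; 0 is a placeholder.
    cam, cam_img, overlay, comparison = process_images(
        dataloader, model, cam_extractor, device, pred_class=0)
    comparison.save("comparison.png")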