import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import warnings
from transformers import AutoProcessor, CLIPModel
import cv2
from huggingface_hub import hf_hub_download
import io

warnings.filterwarnings("ignore", category=UserWarning)


class ImageDataset(Dataset):
    def __init__(self, image, transform=None, face_only=True, dataset_name=None):
        # Modified to accept a single PIL image instead of a list of paths.
        # Force RGB so cv2 color conversion and the CLIP transforms see three channels
        # even for RGBA or grayscale uploads.
        self.image = image.convert("RGB")
        self.transform = transform
        self.face_only = face_only
        self.dataset_name = dataset_name
        # Load face detector
        self.face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    def __len__(self):
        return 1  # Only one image

    def detect_face(self, image_np):
        """Detect a face in the image and return its bounding box and crop."""
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        faces = self.face_detector.detectMultiScale(gray, 1.1, 5)

        # If no face is detected, use the whole image
        if len(faces) == 0:
            print("No face detected, using whole image")
            h, w = image_np.shape[:2]
            return (0, 0, w, h), image_np

        # If several faces are found, keep the largest one by area
        if len(faces) > 1:
            areas = [w * h for (x, y, w, h) in faces]
            largest_idx = np.argmax(areas)
            x, y, w, h = faces[largest_idx]
        else:
            x, y, w, h = faces[0]

        # Add padding around the face (5% on each side - reduced padding)
        padding_x = int(w * 0.05)
        padding_y = int(h * 0.05)

        # Ensure padding doesn't go outside image bounds
        x1 = max(0, x - padding_x)
        y1 = max(0, y - padding_y)
        x2 = min(image_np.shape[1], x + w + padding_x)
        y2 = min(image_np.shape[0], y + h + padding_y)

        # Extract the face region
        face_img = image_np[y1:y2, x1:x2]
        return (x1, y1, x2 - x1, y2 - y1), face_img

    def __getitem__(self, idx):
        # Use the single image provided
        image_np = np.array(self.image)
        label = 0  # Default label; will be overridden by prediction in app.py

        # Store the original image for visualization
        original_image = self.image.copy()

        # Detect face if required
        if self.face_only:
            face_box, face_img_np = self.detect_face(image_np)
            face_img = Image.fromarray(face_img_np)

            # Apply transform to the face crop
            if self.transform:
                face_tensor = self.transform(face_img)
            else:
                face_tensor = transforms.ToTensor()(face_img)

            return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
        else:
            # Process the whole image
            if self.transform:
                image_tensor = self.transform(self.image)
            else:
                image_tensor = transforms.ToTensor()(self.image)

            return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
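
# The tuples returned by __getitem__ mix tensors with PIL images and plain
# Python objects, which torch's default_collate cannot batch. The helper below
# is a sketch of one way to wire this dataset into a DataLoader (an assumption
# about how app.py does it, not part of the original app): stack tensors
# normally and keep everything else as one-element lists, matching the
# `field[0]` indexing used in process_images further down.
def _passthrough_collate(batch):
    fields = list(zip(*batch))
    out = []
    for field in fields:
        if isinstance(field[0], torch.Tensor):
            out.append(torch.stack(field, dim=0))  # batch tensors normally
        else:
            out.append(list(field))  # pass PIL images, tuples, strings through
    return tuple(out)
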
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            if isinstance(output, tuple):
                self.activations = output[0]
            else:
                self.activations = output

        def backward_hook(module, grad_in, grad_out):
            if isinstance(grad_out, tuple):
                self.gradients = grad_out[0]
            else:
                self.gradients = grad_out

        layer = dict([*self.model.named_modules()])[self.target_layer]
        layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated and unreliable on composite
        # modules such as transformer blocks; use the full variant instead
        layer.register_full_backward_hook(backward_hook)

    def generate(self, input_tensor, class_idx):
        self.model.zero_grad()
        try:
            # Use only the vision part of the model for gradient calculation
            vision_outputs = self.model.vision_model(pixel_values=input_tensor)

            # Get the pooled output and backpropagate a one-hot gradient for
            # the requested class index
            features = vision_outputs.pooler_output
            one_hot = torch.zeros_like(features)
            one_hot[0, class_idx] = 1
            features.backward(gradient=one_hot)

            # Check for None values
            if self.gradients is None or self.activations is None:
                print("Warning: Gradients or activations are None. Using fallback CAM.")
                return np.ones((14, 14), dtype=np.float32) * 0.5

            # Process gradients and activations
            if len(self.gradients.shape) == 4:
                # Convolutional layers: weight each channel's activation map by
                # its spatially averaged gradient (standard Grad-CAM)
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()
                weights = np.mean(gradients, axis=(2, 3))
                cam = np.zeros(activations.shape[2:], dtype=np.float32)
                for i, w in enumerate(weights[0]):
                    cam += w * activations[0, i, :, :]
            else:
                # Transformer layers: activations are [batch, sequence_length, hidden_dim]
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()

                if len(activations.shape) == 3:
                    seq_len = activations.shape[1]
                    # A CLIP ViT sequence is one class token followed by a square
                    # grid of patch tokens: 257 = 1 + 16x16 for the ViT-L/14 model
                    # loaded below, 197 = 1 + 14x14 for ViT-B/16
                    side_len = int(np.sqrt(seq_len - 1))
                    if side_len * side_len == seq_len - 1:
                        # Drop the class token and reshape the patch tokens into a
                        # square grid. Note that only activation magnitudes are used
                        # here; the captured gradients are not folded in, so this is
                        # an activation map rather than a gradient-weighted CAM.
                        patch_tokens = activations[0, 1:, :]
                        token_importance = np.mean(np.abs(patch_tokens), axis=1)
                        cam = token_importance.reshape(side_len, side_len)
                    else:
                        # Sequence length is not 1 + a perfect square: pack the token
                        # importances into the nearest square grid as a best effort
                        side_len = int(np.sqrt(seq_len))
                        token_importance = np.mean(np.abs(activations[0]), axis=1)
                        flat_cam = np.zeros(side_len * side_len)
                        n = min(len(token_importance), len(flat_cam))
                        flat_cam[:n] = token_importance[:n]
                        cam = flat_cam.reshape(side_len, side_len)
                else:
                    # Fallback for unexpected activation shapes
                    print("Using fallback CAM shape (14x14)")
                    cam = np.ones((14, 14), dtype=np.float32) * 0.5

            # Ensure we have valid values
            if cam is None or cam.size == 0:
                print("Warning: Generated CAM is empty. Using fallback.")
                cam = np.ones((14, 14), dtype=np.float32) * 0.5

            # Rectify and normalize to [0, 1]
            cam = np.maximum(cam, 0)
            if np.max(cam) > 0:
                cam = cam / np.max(cam)
            return cam
        except Exception as e:
            print(f"Error in GradCAM.generate: {str(e)}")
            return np.ones((14, 14), dtype=np.float32) * 0.5
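
# Quick usage sketch (assumptions: `clip_model` is the CLIPModel returned by
# load_clip_model() below, and `face_tensor` is a CLIP-normalized
# [1, 3, 224, 224] batch):
#
#     cam_extractor = GradCAM(clip_model, get_target_layer_clip(clip_model))
#     cam = cam_extractor.generate(face_tensor, class_idx=1)
#     # cam is a [0, 1]-normalized 16x16 array for ViT-L/14 inputs
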
def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
    if face_box is not None:
        x, y, w, h = face_box

        # Build a full-image CAM (all zeros) and paste the resized face CAM
        # into the face region
        img_np = np.array(image)
        full_h, full_w = img_np.shape[:2]
        full_cam = np.zeros((full_h, full_w), dtype=np.float32)
        face_cam = cv2.resize(cam, (w, h))
        full_cam[y:y+h, x:x+w] = face_cam

        # Apply the jet colormap
        cam_resized = Image.fromarray((full_cam * 255).astype(np.uint8))
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
        cam_colormap = (cam_colormap * 255).astype(np.uint8)
    else:
        cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size, Image.BILINEAR)
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]  # Apply colormap
        cam_colormap = (cam_colormap * 255).astype(np.uint8)

    # Image.blend requires matching modes, so force RGB before blending
    blended = Image.blend(image.convert("RGB"), Image.fromarray(cam_colormap), alpha=alpha)
    return blended


def save_comparison(image, cam, overlay, face_box=None):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image, with the detected face box drawn if available
    axes[0].imshow(image)
    axes[0].set_title("Original")
    if face_box is not None:
        x, y, w, h = face_box
        rect = plt.Rectangle((x, y), w, h, edgecolor='lime', linewidth=2, fill=False)
        axes[0].add_patch(rect)
    axes[0].axis("off")

    # CAM
    if face_box is not None:
        # Create a full-image CAM that highlights only the face
        img_np = np.array(image)
        h, w = img_np.shape[:2]
        full_cam = np.zeros((h, w))

        x, y, fw, fh = face_box
        # Resize the CAM to the face size and place it at the face position
        face_cam = cv2.resize(cam, (fw, fh))
        full_cam[y:y+fh, x:x+fw] = face_cam

        axes[1].imshow(full_cam, cmap="jet")
    else:
        axes[1].imshow(cam, cmap="jet")
    axes[1].set_title("CAM")
    axes[1].axis("off")

    # Overlay
    axes[2].imshow(overlay)
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    plt.tight_layout()

    # Convert the plot to a PIL Image for Streamlit display
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def load_clip_model():
    # Modified to load the fine-tuned checkpoint from the Hugging Face Hub
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
    checkpoint_path = hf_hub_download(repo_id="drg31/model", filename="model.pth")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Keep only checkpoint entries whose name and shape match the base model
    model_dict = model.state_dict()
    checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict and model_dict[k].shape == v.shape}
    model_dict.update(checkpoint)
    model.load_state_dict(model_dict)
    model.eval()
    return model, processor


def get_target_layer_clip(model):
    # For CLIP ViT-L, hook the last encoder block (layers are indexed 0-23),
    # whose output has the [batch, tokens, hidden] shape GradCAM expects
    return "vision_model.encoder.layers.23"
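
# `pred_class` is computed in app.py, which is not shown in this file. A
# plausible zero-shot sketch using the model/processor loaded above (the
# prompt strings and the 0 = real / 1 = fake convention are assumptions, not
# necessarily the app's actual labels):
#
#     inputs = processor(text=["a real photo of a face", "a deepfake face"],
#                        images=pil_image, return_tensors="pt", padding=True)
#     with torch.no_grad():
#         logits = model(**inputs).logits_per_image  # shape [1, 2]
#     pred_class = int(logits.argmax(dim=-1).item())
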
def process_images(dataloader, model, cam_extractor, device, pred_class):
    # Modified to process a single image and return results for Streamlit
    for batch in dataloader:
        input_tensor, label, img_paths, original_images, face_boxes, dataset_names = batch
        original_image = original_images[0]
        face_box = face_boxes[0]

        print("Processing uploaded image...")

        # Move tensors and model to device
        input_tensor = input_tensor.to(device)
        model = model.to(device)

        try:
            # Forward pass and Grad-CAM generation (GradCAM.generate runs its
            # own forward pass; this call just confirms the input is well-formed)
            output = model.vision_model(pixel_values=input_tensor).pooler_output
            class_idx = pred_class  # Use predicted class from app.py
            cam = cam_extractor.generate(input_tensor, class_idx)

            # Generate the CAM image
            if face_box is not None:
                x, y, w, h = face_box
                img_np = np.array(original_image)
                h_full, w_full = img_np.shape[:2]
                full_cam = np.zeros((h_full, w_full))

                face_cam = cv2.resize(cam, (w, h))
                full_cam[y:y+h, x:x+w] = face_cam
                cam_img = Image.fromarray((plt.cm.jet(full_cam)[:, :, :3] * 255).astype(np.uint8))
            else:
                cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
                cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
                cam_colormap = (cam_colormap * 255).astype(np.uint8)
                cam_img = Image.fromarray(cam_colormap)

            # Generate the overlay
            overlay = overlay_cam_on_image(original_image, cam, face_box)

            # Generate the side-by-side comparison
            comparison = save_comparison(original_image, cam, overlay, face_box)

            return cam, cam_img, overlay, comparison
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()

            # Return default values in case of error
            default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
            cam_resized = Image.fromarray((default_cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
            cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
            cam_colormap = (cam_colormap * 255).astype(np.uint8)
            cam_img = Image.fromarray(cam_colormap)
            overlay = overlay_cam_on_image(original_image, default_cam, face_box)
            comparison = save_comparison(original_image, default_cam, overlay, face_box)
            return default_cam, cam_img, overlay, comparison
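
if __name__ == "__main__":
    # End-to-end smoke test: a sketch of how app.py is assumed to drive this
    # pipeline. "face.jpg" is a placeholder path, and pred_class=1 stands in
    # for the classifier decision normally made in app.py.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, processor = load_clip_model()
    pil_image = Image.open("face.jpg").convert("RGB")

    # CLIP's published normalization statistics
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])
    dataset = ImageDataset(pil_image, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=_passthrough_collate)

    cam_extractor = GradCAM(model, get_target_layer_clip(model))
    cam, cam_img, overlay, comparison = process_images(
        dataloader, model, cam_extractor, device, pred_class=1
    )
    comparison.save("comparison.png")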