# deepfake_uq/gradcam_clip_large.py
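"""Grad-CAM visualization utilities for a CLIP ViT-L/14 based deepfake detector.

Provides a single-image Dataset with optional face cropping, a GradCAM helper
adapted to the CLIP vision transformer, and overlay/plotting utilities used by
the accompanying Streamlit app (app.py).
"""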
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import warnings
from transformers import AutoProcessor, CLIPModel
import cv2
import re
from huggingface_hub import hf_hub_download
import io
warnings.filterwarnings("ignore", category=UserWarning)
class ImageDataset(Dataset):
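    """Single-image Dataset: wraps one uploaded PIL image and optionally crops it to the detected face."""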
def __init__(self, image, transform=None, face_only=True, dataset_name=None):
# Modified to accept a single PIL image instead of a list of paths
self.image = image
self.transform = transform
self.face_only = face_only
self.dataset_name = dataset_name
# Load face detector
self.face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
def __len__(self):
return 1 # Only one image
def detect_face(self, image_np):
"""Detect face in image and return the face region"""
gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
faces = self.face_detector.detectMultiScale(gray, 1.1, 5)
# If no face is detected, use the whole image
if len(faces) == 0:
print("No face detected, using whole image")
h, w = image_np.shape[:2]
return (0, 0, w, h), image_np
# Get the largest face
if len(faces) > 1:
# Choose the largest face by area
areas = [w*h for (x, y, w, h) in faces]
largest_idx = np.argmax(areas)
x, y, w, h = faces[largest_idx]
else:
x, y, w, h = faces[0]
        # Add a small padding around the face (5% on each side)
padding_x = int(w * 0.05)
padding_y = int(h * 0.05)
# Ensure padding doesn't go outside image bounds
x1 = max(0, x - padding_x)
y1 = max(0, y - padding_y)
x2 = min(image_np.shape[1], x + w + padding_x)
y2 = min(image_np.shape[0], y + h + padding_y)
# Extract the face region
face_img = image_np[y1:y2, x1:x2]
return (x1, y1, x2-x1, y2-y1), face_img
def __getitem__(self, idx):
# Use the single image provided
image_np = np.array(self.image)
label = 0 # Default label; will be overridden by prediction in app.py
# Store original image for visualization
original_image = self.image.copy()
# Detect face if required
if self.face_only:
face_box, face_img_np = self.detect_face(image_np)
face_img = Image.fromarray(face_img_np)
# Apply transform to face image
if self.transform:
face_tensor = self.transform(face_img)
else:
face_tensor = transforms.ToTensor()(face_img)
return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
else:
# Process the whole image
if self.transform:
image_tensor = self.transform(self.image)
else:
image_tensor = transforms.ToTensor()(self.image)
return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
class GradCAM:
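    """Grad-CAM helper adapted to the CLIP vision transformer.

    Registers hooks on a target layer to capture activations and gradients, then
    turns them into a coarse spatial importance map (patch-token importances
    reshaped onto the ViT patch grid).
    """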
def __init__(self, model, target_layer):
self.model = model
self.target_layer = target_layer
self.gradients = None
self.activations = None
self._register_hooks()
def _register_hooks(self):
def forward_hook(module, input, output):
if isinstance(output, tuple):
self.activations = output[0]
else:
self.activations = output
def backward_hook(module, grad_in, grad_out):
if isinstance(grad_out, tuple):
self.gradients = grad_out[0]
else:
self.gradients = grad_out
layer = dict([*self.model.named_modules()])[self.target_layer]
layer.register_forward_hook(forward_hook)
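        # Note: register_backward_hook is deprecated in recent PyTorch releases in favor of
        # register_full_backward_hook (same (module, grad_input, grad_output) signature)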
layer.register_backward_hook(backward_hook)
def generate(self, input_tensor, class_idx):
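        """Return a CAM as a 2-D numpy array normalized to [0, 1] for the given input and class index."""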
self.model.zero_grad()
try:
# Use only the vision part of the model for gradient calculation
vision_outputs = self.model.vision_model(pixel_values=input_tensor)
# Get the pooler output
features = vision_outputs.pooler_output
# Create a dummy gradient for the feature based on the class idx
one_hot = torch.zeros_like(features)
one_hot[0, class_idx] = 1
# Manually backpropagate
features.backward(gradient=one_hot)
# Check for None values
if self.gradients is None or self.activations is None:
print("Warning: Gradients or activations are None. Using fallback CAM.")
return np.ones((14, 14), dtype=np.float32) * 0.5
# Process gradients and activations
if len(self.gradients.shape) == 4: # Expected shape for convolutional layers
gradients = self.gradients.cpu().detach().numpy()
activations = self.activations.cpu().detach().numpy()
weights = np.mean(gradients, axis=(2, 3))
cam = np.zeros(activations.shape[2:], dtype=np.float32)
for i, w in enumerate(weights[0]):
cam += w * activations[0, i, :, :]
else:
# Handle transformer model format
gradients = self.gradients.cpu().detach().numpy()
activations = self.activations.cpu().detach().numpy()
                if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
                    seq_len = activations.shape[1]
                    # ViT-style sequence: 1 class token followed by N*N patch tokens
                    # (e.g. 257 = 1 + 16*16 for CLIP ViT-L/14 at 224px; 197 = 1 + 14*14 for ViT-B/16)
                    grid_len = int(np.sqrt(seq_len - 1))
                    if grid_len * grid_len == seq_len - 1:
                        # Skip the class token (first token) and reshape the patch tokens into a square grid
                        patch_tokens = activations[0, 1:, :]  # Remove the class token
                        # Mean absolute activation across the hidden dimension as token importance
                        token_importance = np.mean(np.abs(patch_tokens), axis=1)
                        cam = token_importance.reshape(grid_len, grid_len)
                    else:
                        # Fall back to the closest square grid and fill it with token importances
                        side_len = int(np.sqrt(seq_len))
                        token_importance = np.mean(np.abs(activations[0]), axis=1)
                        flat_cam = np.zeros(side_len * side_len, dtype=np.float32)
                        n = min(len(token_importance), len(flat_cam))
                        flat_cam[:n] = token_importance[:n]
                        cam = flat_cam.reshape(side_len, side_len)
else:
# Fallback
print("Using fallback CAM shape (14x14)")
cam = np.ones((14, 14), dtype=np.float32) * 0.5 # Default fallback
# Ensure we have valid values
if cam is None or cam.size == 0:
print("Warning: Generated CAM is empty. Using fallback.")
cam = np.ones((14, 14), dtype=np.float32) * 0.5
cam = np.maximum(cam, 0)
if np.max(cam) > 0:
cam = cam / np.max(cam)
return cam
except Exception as e:
print(f"Error in GradCAM.generate: {str(e)}")
return np.ones((14, 14), dtype=np.float32) * 0.5
def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
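    """Blend a jet-colormapped CAM over the image; if face_box is given, the CAM is placed only on the face region."""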
if face_box is not None:
x, y, w, h = face_box
# Create a mask for the entire image (all zeros initially)
img_np = np.array(image)
full_h, full_w = img_np.shape[:2]
full_cam = np.zeros((full_h, full_w), dtype=np.float32)
# Resize CAM to match face region
face_cam = cv2.resize(cam, (w, h))
# Copy the face CAM into the full image CAM at the face position
full_cam[y:y+h, x:x+w] = face_cam
# Convert full CAM to image
cam_resized = Image.fromarray((full_cam * 255).astype(np.uint8))
cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3] # Apply colormap
cam_colormap = (cam_colormap * 255).astype(np.uint8)
else:
cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size, Image.BILINEAR)
cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3] # Apply colormap
cam_colormap = (cam_colormap * 255).astype(np.uint8)
    blended = Image.blend(image.convert("RGB"), Image.fromarray(cam_colormap), alpha=alpha)
return blended
def save_comparison(image, cam, overlay, face_box=None):
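    """Build a 3-panel comparison figure (original with face box, CAM, overlay) and return it as a PIL image."""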
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# Original Image
axes[0].imshow(image)
axes[0].set_title("Original")
if face_box is not None:
x, y, w, h = face_box
rect = plt.Rectangle((x, y), w, h, edgecolor='lime', linewidth=2, fill=False)
axes[0].add_patch(rect)
axes[0].axis("off")
# CAM
if face_box is not None:
# Create a full image CAM that highlights only the face
img_np = np.array(image)
h, w = img_np.shape[:2]
full_cam = np.zeros((h, w))
x, y, fw, fh = face_box
# Resize CAM to face size
face_cam = cv2.resize(cam, (fw, fh))
# Place it in the right position
full_cam[y:y+fh, x:x+fw] = face_cam
axes[1].imshow(full_cam, cmap="jet")
else:
axes[1].imshow(cam, cmap="jet")
axes[1].set_title("CAM")
axes[1].axis("off")
# Overlay
axes[2].imshow(overlay)
axes[2].set_title("Overlay")
axes[2].axis("off")
plt.tight_layout()
# Convert plot to PIL Image for Streamlit display
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
plt.close()
buf.seek(0)
return Image.open(buf)
def load_clip_model():
# Modified to load checkpoint from Hugging Face
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
checkpoint_path = hf_hub_download(repo_id="drg31/model", filename="model.pth")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model_dict = model.state_dict()
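    # Keep only checkpoint entries that also exist in the base CLIP model with matching shapes
    # (a tolerant, partial state-dict load)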
checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict and model_dict[k].shape == v.shape}
model_dict.update(checkpoint)
model.load_state_dict(model_dict)
model.eval()
return model, processor
def get_target_layer_clip(model):
    # Use the final transformer block of the vision encoder
    # (CLIP ViT-L/14 has 24 encoder layers, indexed 0-23)
    return "vision_model.encoder.layers.23"
def process_images(dataloader, model, cam_extractor, device, pred_class):
# Modified to process a single image and return results for Streamlit
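    # Assumes batch_size=1 and a collate_fn that keeps PIL images, tuples and None
    # values intact (PyTorch's default_collate cannot batch PIL images)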
for batch in dataloader:
input_tensor, label, img_paths, original_images, face_boxes, dataset_names = batch
original_image = original_images[0]
face_box = face_boxes[0]
print(f"Processing uploaded image...")
# Move tensors and model to device
input_tensor = input_tensor.to(device)
model = model.to(device)
try:
            # Grad-CAM generation (GradCAM.generate runs its own forward pass on the vision model)
            class_idx = pred_class  # Use predicted class from app.py
            cam = cam_extractor.generate(input_tensor, class_idx)
# Generate CAM image
if face_box is not None:
x, y, w, h = face_box
img_np = np.array(original_image)
h_full, w_full = img_np.shape[:2]
full_cam = np.zeros((h_full, w_full))
face_cam = cv2.resize(cam, (w, h))
full_cam[y:y+h, x:x+w] = face_cam
cam_img = Image.fromarray((plt.cm.jet(full_cam)[:, :, :3] * 255).astype(np.uint8))
else:
cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
cam_colormap = (cam_colormap * 255).astype(np.uint8)
cam_img = Image.fromarray(cam_colormap)
# Generate Overlay
overlay = overlay_cam_on_image(original_image, cam, face_box)
# Generate Comparison
comparison = save_comparison(original_image, cam, overlay, face_box)
return cam, cam_img, overlay, comparison
except Exception as e:
print(f"Error processing image: {str(e)}")
import traceback
traceback.print_exc()
# Return default values in case of error
default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
cam_resized = Image.fromarray((default_cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
cam_colormap = (cam_colormap * 255).astype(np.uint8)
cam_img = Image.fromarray(cam_colormap)
overlay = overlay_cam_on_image(original_image, default_cam, face_box)
comparison = save_comparison(original_image, default_cam, overlay, face_box)
return default_cam, cam_img, overlay, comparison
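# --- Usage sketch (illustrative only, not part of the original pipeline) ---
# A minimal sketch of how these pieces could be wired together for one image.
# The transform, the custom collate function, the input path "example.jpg" and
# pred_class=1 are assumptions for illustration; the real caller (app.py) may
# build these differently.
def _collate_single(batch):
    # batch_size is 1: stack the image tensor, keep everything else as plain lists
    tensors, labels, paths, images, boxes, names = zip(*batch)
    return torch.stack(tensors), list(labels), list(paths), list(images), list(boxes), list(names)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = load_clip_model()
    # processor is not used in this sketch; the transform below mirrors CLIP preprocessing
    # for ViT-L/14 (224x224 input, CLIP normalization statistics)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])
    image = Image.open("example.jpg").convert("RGB")  # hypothetical input path
    dataset = ImageDataset(image, transform=transform, face_only=True)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=_collate_single)
    cam_extractor = GradCAM(model, get_target_layer_clip(model))
    # pred_class would normally come from the deepfake classifier in app.py
    cam, cam_img, overlay, comparison = process_images(
        dataloader, model, cam_extractor, device, pred_class=1)
    comparison.save("gradcam_comparison.png")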