import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import warnings
from transformers import AutoProcessor, CLIPModel
import cv2
import re
from huggingface_hub import hf_hub_download
import io

warnings.filterwarnings("ignore", category=UserWarning)

class ImageDataset(Dataset):
    """Wraps a single uploaded PIL image so it can be served through a DataLoader."""

    def __init__(self, image, transform=None, face_only=True, dataset_name=None):
        self.image = image
        self.transform = transform
        self.face_only = face_only
        self.dataset_name = dataset_name

        # OpenCV's bundled Haar cascade for frontal face detection
        self.face_detector = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    def __len__(self):
        # The dataset always holds exactly one image
        return 1

    def detect_face(self, image_np):
        """Detect a face in the image and return its bounding box and cropped region."""
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        faces = self.face_detector.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)

        # Fall back to the whole image when no face is found
        if len(faces) == 0:
            print("No face detected, using whole image")
            h, w = image_np.shape[:2]
            return (0, 0, w, h), image_np

        # If several faces are detected, keep the largest one
        if len(faces) > 1:
            areas = [w * h for (x, y, w, h) in faces]
            largest_idx = np.argmax(areas)
            x, y, w, h = faces[largest_idx]
        else:
            x, y, w, h = faces[0]

        # Pad the box by 5% on each side, clamped to the image bounds
        padding_x = int(w * 0.05)
        padding_y = int(h * 0.05)
        x1 = max(0, x - padding_x)
        y1 = max(0, y - padding_y)
        x2 = min(image_np.shape[1], x + w + padding_x)
        y2 = min(image_np.shape[0], y + h + padding_y)

        face_img = image_np[y1:y2, x1:x2]
        return (x1, y1, x2 - x1, y2 - y1), face_img

    def __getitem__(self, idx):
        image_np = np.array(self.image)
        label = 0  # placeholder; the actual prediction is supplied by the caller
        original_image = self.image.copy()

        if self.face_only:
            face_box, face_img_np = self.detect_face(image_np)
            face_img = Image.fromarray(face_img_np)

            if self.transform:
                face_tensor = self.transform(face_img)
            else:
                face_tensor = transforms.ToTensor()(face_img)

            return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
        else:
            if self.transform:
                image_tensor = self.transform(self.image)
            else:
                image_tensor = transforms.ToTensor()(self.image)

            return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name

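
# ImageDataset yields PIL images, tuples, and possibly None, which torch's
# default_collate cannot batch, while process_images() below indexes every
# batch field with [0]. A collate that transposes the batch into per-field
# tuples satisfies both. This is a minimal sketch; the name `collate_single`
# is illustrative, not part of the original source.
def collate_single(batch):
    # Transpose [(tensor, label, path, image, box, name), ...] into per-field
    # tuples, stacking only the image tensors into a 4-D batch.
    tensors, labels, paths, images, boxes, names = zip(*batch)
    return torch.stack(tensors), labels, paths, images, boxes, names
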
class GradCAM:
    """Grad-CAM for the CLIP vision tower: captures activations and gradients
    at a target layer and combines them into a class activation map."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            # Transformer blocks return tuples; the hidden states come first
            if isinstance(output, tuple):
                self.activations = output[0]
            else:
                self.activations = output

        def backward_hook(module, grad_in, grad_out):
            if isinstance(grad_out, tuple):
                self.gradients = grad_out[0]
            else:
                self.gradients = grad_out

        layer = dict(self.model.named_modules())[self.target_layer]
        layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated; use the full variant instead
        layer.register_full_backward_hook(backward_hook)

    def generate(self, input_tensor, class_idx):
        self.model.zero_grad()

        try:
            # Forward pass through the vision tower only
            vision_outputs = self.model.vision_model(pixel_values=input_tensor)
            features = vision_outputs.pooler_output

            # Backpropagate from the selected feature dimension
            one_hot = torch.zeros_like(features)
            one_hot[0, class_idx] = 1
            features.backward(gradient=one_hot)

            if self.gradients is None or self.activations is None:
                print("Warning: Gradients or activations are None. Using fallback CAM.")
                return np.ones((14, 14), dtype=np.float32) * 0.5

            if len(self.gradients.shape) == 4:
                # Convolutional case: weight each channel by its mean gradient
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()

                weights = np.mean(gradients, axis=(2, 3))
                cam = np.zeros(activations.shape[2:], dtype=np.float32)
                for i, w in enumerate(weights[0]):
                    cam += w * activations[0, i, :, :]
            else:
                # Transformer case: activations are (batch, seq_len, hidden)
                gradients = self.gradients.cpu().detach().numpy()
                activations = self.activations.cpu().detach().numpy()

                if len(activations.shape) == 3:
                    seq_len = activations.shape[1]
                    side_len = int(np.sqrt(seq_len - 1))

                    if side_len * side_len == seq_len - 1:
                        # seq_len is a square grid of patches plus a CLS token
                        # (197 -> 14x14 for ViT-B/16, 257 -> 16x16 for ViT-L/14):
                        # drop the CLS token and reshape patch importance to a grid
                        patch_tokens = activations[0, 1:, :]
                        token_importance = np.mean(np.abs(patch_tokens), axis=1)
                        cam = token_importance.reshape(side_len, side_len)
                    else:
                        # Unexpected sequence length: fit as many tokens as possible
                        side_len = int(np.sqrt(seq_len))
                        token_importance = np.mean(np.abs(activations[0]), axis=1)
                        flat_cam = np.zeros(side_len * side_len)
                        n = min(len(token_importance), len(flat_cam))
                        flat_cam[:n] = token_importance[:n]
                        cam = flat_cam.reshape(side_len, side_len)
                else:
                    print("Using fallback CAM shape (14x14)")
                    cam = np.ones((14, 14), dtype=np.float32) * 0.5

            if cam is None or cam.size == 0:
                print("Warning: Generated CAM is empty. Using fallback.")
                cam = np.ones((14, 14), dtype=np.float32) * 0.5

            # ReLU, then normalize to [0, 1]
            cam = np.maximum(cam, 0)
            if np.max(cam) > 0:
                cam = cam / np.max(cam)

            return cam

        except Exception as e:
            print(f"Error in GradCAM.generate: {str(e)}")
            return np.ones((14, 14), dtype=np.float32) * 0.5

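
# A quick smoke test of GradCAM in isolation (an illustrative sketch, not part
# of the original source; the random tensor stands in for a preprocessed
# 224x224 image batch). For openai/clip-vit-large-patch14 the target layer
# sees 257 tokens (16x16 patches plus CLS), so the CLS-aware reshape above
# yields a 16x16 map:
#
#   model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
#   cam_extractor = GradCAM(model, "vision_model.encoder.layers.23")
#   cam = cam_extractor.generate(torch.randn(1, 3, 224, 224), class_idx=0)
#   assert cam.shape == (16, 16)
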
def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
    """Blend the CAM heatmap over the original image, confined to the face box if given."""
    if face_box is not None:
        x, y, w, h = face_box
        img_np = np.array(image)
        full_h, full_w = img_np.shape[:2]
        full_cam = np.zeros((full_h, full_w), dtype=np.float32)

        # Resize the CAM to the face box and paste it into a full-size canvas
        face_cam = cv2.resize(cam, (w, h))
        full_cam[y:y + h, x:x + w] = face_cam

        cam_colormap = plt.cm.jet(full_cam)[:, :, :3]
        cam_colormap = (cam_colormap * 255).astype(np.uint8)
    else:
        cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size, Image.BILINEAR)
        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
        cam_colormap = (cam_colormap * 255).astype(np.uint8)

    # Image.blend requires matching modes and sizes, so force RGB on the original
    blended = Image.blend(image.convert("RGB"), Image.fromarray(cam_colormap), alpha=alpha)
    return blended

def save_comparison(image, cam, overlay, face_box=None):
    """Render a side-by-side figure: original (with face box), raw CAM, and overlay."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image, with the detected face outlined if available
    axes[0].imshow(image)
    axes[0].set_title("Original")
    if face_box is not None:
        x, y, w, h = face_box
        rect = plt.Rectangle((x, y), w, h, edgecolor='lime', linewidth=2, fill=False)
        axes[0].add_patch(rect)
    axes[0].axis("off")

    # Raw CAM, pasted into a full-size canvas when it only covers the face
    if face_box is not None:
        img_np = np.array(image)
        h, w = img_np.shape[:2]
        full_cam = np.zeros((h, w))
        x, y, fw, fh = face_box
        face_cam = cv2.resize(cam, (fw, fh))
        full_cam[y:y + fh, x:x + fw] = face_cam
        axes[1].imshow(full_cam, cmap="jet")
    else:
        axes[1].imshow(cam, cmap="jet")
    axes[1].set_title("CAM")
    axes[1].axis("off")

    # Heatmap blended over the original
    axes[2].imshow(overlay)
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    plt.tight_layout()

    # Return the figure as a PIL image instead of writing it to disk
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

def load_clip_model():
    """Load CLIP ViT-L/14 and overlay the fine-tuned weights from the Hub checkpoint."""
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    checkpoint_path = hf_hub_download(repo_id="drg31/model", filename="model.pth")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Keep only checkpoint entries whose names and shapes match the base model
    model_dict = model.state_dict()
    checkpoint = {k: v for k, v in checkpoint.items()
                  if k in model_dict and model_dict[k].shape == v.shape}
    model_dict.update(checkpoint)
    model.load_state_dict(model_dict)

    model.eval()
    return model, processor

def get_target_layer_clip(model):
    # Last encoder block of the 24-layer ViT-L/14 vision tower (layers 0-23)
    return "vision_model.encoder.layers.23"

def process_images(dataloader, model, cam_extractor, device, pred_class):
    """Run Grad-CAM on the single image in the dataloader and return the raw CAM,
    a heatmap image, the blended overlay, and a comparison figure."""
    for batch in dataloader:
        input_tensor, label, img_paths, original_images, face_boxes, dataset_names = batch
        original_image = original_images[0]
        face_box = face_boxes[0]

        print("Processing uploaded image...")

        input_tensor = input_tensor.to(device)
        model = model.to(device)

        try:
            # generate() runs its own forward pass through the vision tower
            cam = cam_extractor.generate(input_tensor, pred_class)

            # Build a full-size heatmap image from the CAM
            if face_box is not None:
                x, y, w, h = face_box
                img_np = np.array(original_image)
                h_full, w_full = img_np.shape[:2]
                full_cam = np.zeros((h_full, w_full))
                face_cam = cv2.resize(cam, (w, h))
                full_cam[y:y + h, x:x + w] = face_cam
                cam_img = Image.fromarray((plt.cm.jet(full_cam)[:, :, :3] * 255).astype(np.uint8))
            else:
                cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
                cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
                cam_colormap = (cam_colormap * 255).astype(np.uint8)
                cam_img = Image.fromarray(cam_colormap)

            overlay = overlay_cam_on_image(original_image, cam, face_box)
            comparison = save_comparison(original_image, cam, overlay, face_box)
            return cam, cam_img, overlay, comparison

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()

            # Fall back to a flat CAM so the caller always receives usable outputs
            default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
            cam_resized = Image.fromarray((default_cam * 255).astype(np.uint8)).resize(original_image.size, Image.BILINEAR)
            cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]
            cam_colormap = (cam_colormap * 255).astype(np.uint8)
            cam_img = Image.fromarray(cam_colormap)
            overlay = overlay_cam_on_image(original_image, default_cam, face_box)
            comparison = save_comparison(original_image, default_cam, overlay, face_box)
            return default_cam, cam_img, overlay, comparison
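

# End-to-end usage sketch (an assumption, not part of the original source):
# the normalization constants are the standard CLIP image statistics, the
# image path is a placeholder, `collate_single` is the illustrative helper
# defined above, and pred_class=0 stands in for the classifier's prediction.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, processor = load_clip_model()

    # Match CLIP's expected input: 224x224, normalized with its mean/std
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    image = Image.open("example.jpg").convert("RGB")
    dataset = ImageDataset(image, transform=transform, face_only=True)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_single)

    cam_extractor = GradCAM(model, get_target_layer_clip(model))
    cam, cam_img, overlay, comparison = process_images(
        dataloader, model, cam_extractor, device, pred_class=0)
    comparison.save("comparison.png")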