from typing import Tuple
from warnings import warn

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch import Tensor
import spaces
import numpy as np
from PIL import Image
import gradio as gr
from matplotlib import cm
from huggingface_hub import hf_hub_download

from models import get_model

# Per-channel mean / std handed to `TF.normalize` during preprocessing.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# Blend weight of the colour-mapped overlay (passed to `Image.blend`).
alpha = 0.8
# Small constant guarding divisions by (near-)zero ranges.
EPS = 1e-8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Datasets for which pretrained checkpoints are offered, per model variant.
pretrained_datasets = {
    # NOTE: the comma after "UCF-QNRF" matters; without it Python silently
    # concatenates the two adjacent string literals into one bogus entry.
    "ZIP-B": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF", "NWPU-Crowd"],
    "ZIP-S": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-T": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-N": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-P": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
}

# Currently loaded model. Defined up front so that code checking
# `loaded_model is None` before the first load does not hit a NameError.
loaded_model = None


# -----------------------------
# Model loading
# -----------------------------
def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae") -> None:
    """Download a checkpoint from the Hugging Face Hub and install it globally.

    The file ``checkpoints/{dataset}/best_{metric}.pth`` is fetched from the
    repository ``Yiming-M/{variant}``, turned into a model with ``get_model``,
    switched to eval mode and stored in the module-level ``loaded_model``.

    Args:
        variant: Model variant; used as the repository name on the Hub.
        dataset: Dataset folder name inside the repository's ``checkpoints``
            directory.
        metric: Metric whose best checkpoint should be loaded.
    """
    global loaded_model

    model_info_path = hf_hub_download(
        repo_id=f"Yiming-M/{variant}",
        filename=f"checkpoints/{dataset}/best_{metric}.pth",
    )

    model = get_model(model_info_path=model_info_path)
    model.eval()
    loaded_model = model


def _calc_size(
    img_w: int,
    img_h: int,
    min_size: int,
    max_size: float,
    base: int = 32,
) -> Tuple[int, int]:
    """Compute a new image size that (approximately) keeps the aspect ratio.

    Both edges of the returned size are multiples of ``base`` and are clamped
    to ``[min_size, max_size]``. If no size in that range can preserve the
    aspect ratio, a warning is issued and the computation is repeated without
    an upper bound.

    Args:
        img_w: The width of the image.
        img_h: The height of the image.
        min_size: The minimum size of the edges of the image; must be a
            multiple of ``base``.
        max_size: The maximum size of the edges of the image; either a
            multiple of ``base`` or ``float("inf")`` for "no limit".
        base: The number the new edge lengths must be a multiple of.

    Returns:
        The new size as ``(new_w, new_h)``.

    Raises:
        AssertionError: If ``min_size``/``max_size`` violate the constraints
            described above.
    """
    assert min_size % base == 0, f"min_size ({min_size}) must be a multiple of {base}"
    if max_size != float("inf"):
        assert max_size % base == 0, f"max_size ({max_size}) must be a multiple of {base} if provided"

    assert min_size <= max_size, f"min_size ({min_size}) must be less than or equal to max_size ({max_size})"

    aspect_ratios = (img_w / img_h, img_h / img_w)
    preservable = min_size / max_size <= min(aspect_ratios) <= max(aspect_ratios) <= max_size / min_size

    if not preservable:
        # Impossible to resize and preserve the aspect ratio: drop the upper
        # bound. With max_size == inf the check above always passes, so the
        # recursion is at most one level deep.
        msg = f"Impossible to resize {img_w}x{img_h} image while preserving the aspect ratio to a size within the range ({min_size}, {max_size}). Will not limit the maximum size."
        warn(msg)
        return _calc_size(img_w, img_h, min_size, float("inf"), base)

    short_edge, long_edge = min(img_w, img_h), max(img_w, img_h)
    if min_size <= short_edge <= long_edge <= max_size:
        # Already within the range, no need to rescale.
        ratio = 1.
    elif short_edge < min_size:
        # Smaller than the minimum size: scale the short edge up to it.
        ratio = min_size / short_edge
    else:
        # Larger than the maximum size: scale the long edge down to it.
        ratio = max_size / long_edge

    # Snap to the nearest multiple of `base`, then clamp because the rounding
    # can push an edge slightly outside the allowed range.
    new_w = int(round(img_w * ratio / base) * base)
    new_h = int(round(img_h * ratio / base) * base)
    new_w = max(min_size, min(max_size, new_w))
    new_h = max(min_size, min(max_size, new_h))
    return new_w, new_h


# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
# (min_size, max_size) edge limits used when resizing, keyed by the short
# dataset name.
_SIZE_LIMITS = {
    "sha": (448, float("inf")),
    "shb": (448, float("inf")),
    "qnrf": (448, 2048),
    "nwpu": (448, 3072),
}

# UI dataset label -> short dataset name.
_DATASET_SHORT_NAMES = {
    "ShanghaiTech A": "sha",
    "ShanghaiTech B": "shb",
    "UCF-QNRF": "qnrf",
    "NWPU-Crowd": "nwpu",
}

# (variant, dataset_name, metric) of the checkpoint that is currently loaded;
# lets `predict` reload when the user changes the selection.
_loaded_model_key = None


def transform(image: Image.Image, dataset_name: str) -> Tensor:
    """Turn a PIL image into a normalised, batched tensor for the model.

    The image is resized (if needed) to the size returned by ``_calc_size``
    for the dataset's edge limits, converted to a tensor, normalised with the
    module-level ``mean``/``std`` and given a leading batch dimension.

    Args:
        image: Input image.
        dataset_name: Short dataset name: "sha", "shb", "qnrf" or "nwpu".

    Returns:
        Tensor of shape (1, C, H, W).

    Raises:
        AssertionError: If ``image`` is not a PIL image.
        ValueError: If ``dataset_name`` is unknown.
    """
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    if dataset_name not in _SIZE_LIMITS:
        raise ValueError(
            f"Unknown dataset name {dataset_name!r}; expected one of {sorted(_SIZE_LIMITS)}"
        )
    min_size, max_size = _SIZE_LIMITS[dataset_name]

    # `mean`/`std` have three entries, so the tensor must have three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_width, image_height = image.size
    new_width, new_height = _calc_size(
        img_w=image_width,
        img_h=image_height,
        min_size=min_size,
        max_size=max_size,
        base=32,
    )
    if new_height != image_height or new_width != image_width:
        # Resize the PIL image rather than the tensor: torchvision supports
        # LANCZOS only for PIL inputs (tensor inputs accept nearest /
        # bilinear / bicubic only).
        image = TF.resize(
            image,
            size=[new_height, new_width],
            interpolation=TF.InterpolationMode.LANCZOS,
        )

    image_tensor = TF.to_tensor(image)
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension


def _clip_maps(model: nn.Module, images: Tensor):
    """Run the backbone and both text-matching heads; return (pi_map, lambda_map).

    Image features from ``pi_head`` / ``lambda_head`` are L2-normalised along
    the channel axis and multiplied with the model's text features to obtain
    per-location logits. Both returned maps have shape (B, 1, H, W). Callers
    are expected to wrap this in ``torch.no_grad()``.
    """
    feats = model.backbone(images)
    pi_feats = F.normalize(model.pi_head(feats).permute(0, 2, 3, 1), p=2, dim=-1)  # (B, H, W, C)
    lambda_feats = F.normalize(model.lambda_head(feats).permute(0, 2, 3, 1), p=2, dim=-1)  # (B, H, W, C)

    pi_logit_scale = model.pi_logit_scale.exp()
    lambda_logit_scale = model.lambda_logit_scale.exp()

    pi_logit_map = pi_logit_scale * pi_feats @ model.pi_text_feats.t()  # (B, H, W, 2)
    lambda_logit_map = lambda_logit_scale * lambda_feats @ model.lambda_text_feats.t()  # (B, H, W, N - 1)

    pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
    lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)

    # Expected value over the bins (the first bin centre is skipped).
    lambda_map = (lambda_logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)
    pi_map = pi_logit_map.softmax(dim=1)[:, 0:1]
    return pi_map, lambda_map


def _sliding_window_predict(
    model: nn.Module,
    image: Tensor,
    window_size: int,
    stride: int,
    max_num_windows: int = 256
):
    """Predict pi / lambda maps by tiling the image with fixed-size windows.

    Windows are laid out on a grid with the given stride; windows that would
    run past the border are shifted back inside the image. Per-window
    predictions are accumulated on a grid downsampled by the model's
    ``block_size`` and averaged where windows overlap.

    Args:
        model: Model exposing ``block_size`` (directly or via ``.module``)
            plus the attributes used by ``_clip_maps``.
        image: Tensor of shape (1, c, h, w).
        window_size: Window edge length, or an (h, w) pair.
        stride: Stride, or an (h, w) pair; must not exceed the window size.
        max_num_windows: Maximum number of windows per forward pass.

    Returns:
        ``(pi_map, lambda_map)``: CPU tensors of shape
        (1, 1, h // block_size, w // block_size).

    Raises:
        AssertionError: On invalid shapes, window sizes or strides.
        ValueError: If the model has no ``block_size`` attribute.
    """
    assert len(image.shape) == 4, f"Image must be a 4D tensor (1, c, h, w), got {image.shape}"
    window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
    stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
    window_size = tuple(window_size)
    stride = tuple(stride)
    assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, f"Window size must be a positive integer tuple (h, w), got {window_size}"
    assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, f"Stride must be a positive integer tuple (h, w), got {stride}"
    assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"Stride must be smaller than window size, got {stride} and {window_size}"

    image_height, image_width = image.shape[-2:]
    window_height, window_width = window_size
    assert image_height >= window_height and image_width >= window_width, f"Image size must be larger than window size, got image size {image.shape} and window size {window_size}"
    stride_height, stride_width = stride

    num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
    num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)

    # Unwrap a `.module` wrapper so that the attributes used below
    # (backbone, heads, ...) are read from the same object as block_size.
    if hasattr(model, "block_size"):
        net = model
    elif hasattr(model, "module") and hasattr(model.module, "block_size"):
        net = model.module
    else:
        raise ValueError("Model must have block_size attribute")
    block_size = net.block_size
    assert window_height % block_size == 0 and window_width % block_size == 0, f"Window size must be divisible by block size, got {window_size} and {block_size}"

    # (top, bottom, left, right) of every window in row-major order. A window
    # overshooting the border is shifted back so it ends exactly at the border.
    boxes = []
    for i in range(num_rows):
        top = min(i * stride_height, image_height - window_height)
        for j in range(num_cols):
            left = min(j * stride_width, image_width - window_width)
            boxes.append((top, top + window_height, left, left + window_width))

    windows = torch.cat(
        [image[:, :, top:bottom, left:right] for top, bottom, left, right in boxes],
        dim=0,
    ).to(image.device)  # batched windows, shape: (num_windows, c, h, w)

    net.eval()
    pi_maps, lambda_maps = [], []
    for start in range(0, len(windows), max_num_windows):
        with torch.no_grad():
            pi_batch, lambda_batch = _clip_maps(net, windows[start: start + max_num_windows])
        pi_maps.append(pi_batch.cpu().numpy())
        lambda_maps.append(lambda_batch.cpu().numpy())

    # Assemble the full-size maps.
    pi_maps = np.concatenate(pi_maps, axis=0)  # shape: (num_windows, 1, H, W)
    lambda_maps = np.concatenate(lambda_maps, axis=0)  # shape: (num_windows, 1, H, W)
    assert pi_maps.shape == lambda_maps.shape, f"pi_maps and lambda_maps must have the same shape, got {pi_maps.shape} and {lambda_maps.shape}"

    out_shape = (pi_maps.shape[1], image_height // block_size, image_width // block_size)
    pi_map = np.zeros(out_shape, dtype=np.float32)
    lambda_map = np.zeros(out_shape, dtype=np.float32)
    count_map = np.zeros(out_shape, dtype=np.float32)

    for idx, (top, bottom, left, right) in enumerate(boxes):
        rows = slice(top // block_size, bottom // block_size)
        cols = slice(left // block_size, right // block_size)
        pi_map[:, rows, cols] += pi_maps[idx]
        lambda_map[:, rows, cols] += lambda_maps[idx]
        count_map[:, rows, cols] += 1.

    # Average over the number of windows covering each cell.
    pi_map /= count_map
    lambda_map /= count_map

    # Convert to Tensor and add the batch dimension.
    pi_map = torch.from_numpy(pi_map).unsqueeze(0)  # shape: (1, 1, H // block_size, W // block_size)
    lambda_map = torch.from_numpy(lambda_map).unsqueeze(0)  # shape: (1, 1, H // block_size, W // block_size)
    return pi_map, lambda_map


def _normalize_map(x: np.ndarray) -> np.ndarray:
    """Min-max normalise ``x`` to [0, 1]; an (almost) constant map becomes zeros."""
    x_min = np.min(x)
    x_max = np.max(x)
    if x_max - x_min < EPS:
        return np.zeros_like(x)
    return (x - x_min) / (x_max - x_min + EPS)


def _upsample_map(map_tensor: Tensor, height: int, width: int) -> np.ndarray:
    """Bilinearly resize a (1, 1, h, w) map to (height, width) as a 2-D array."""
    resized = F.interpolate(
        map_tensor, size=(height, width), mode="bilinear", align_corners=False, antialias=True
    )
    # Index instead of squeeze() so that size-1 spatial dims are not dropped.
    return resized.cpu()[0, 0].numpy()


def _overlay(image_rgba: Image.Image, heatmap: np.ndarray) -> Image.Image:
    """Colour-map a 2-D map with 'jet' and blend it over ``image_rgba``."""
    # `cm.get_cmap` was removed in matplotlib 3.9; the module attribute is
    # available in old and new releases alike.
    colormap = cm.jet
    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    rgba = (colormap(_normalize_map(heatmap)) * 255).astype(np.uint8)
    return Image.blend(image_rgba, Image.fromarray(rgba).convert("RGBA"), alpha=alpha)


# -----------------------------
# Inference function
# -----------------------------
@spaces.GPU(duration=120)
def predict(image: Image.Image, variant: str, dataset: str, metric: str):
    """Run crowd counting on ``image`` and build the visualisations.

    The image is preprocessed, the selected checkpoint is (re)loaded if it is
    not the one currently in memory, the pi / lambda maps are predicted, the
    total count is computed from the density map ``(1 - pi) * lambda``, and
    every map is resized to the input resolution and blended over the image.

    Args:
        image: Input image.
        variant: Model variant (a key of ``pretrained_datasets``).
        dataset: Dataset label as shown in the UI, e.g. "ShanghaiTech A".
        metric: Which "best" checkpoint to use, e.g. "mae".

    Returns:
        ``(image, structural_zero_map, sampling_zero_map, complete_zero_map,
        lambda_map, den_map, count_text)``, in the order of the UI outputs.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If ``dataset`` is not a known dataset label.
    """
    global _loaded_model_key

    if image is None:
        raise gr.Error("Please provide an input image.")
    if dataset not in _DATASET_SHORT_NAMES:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(_DATASET_SHORT_NAMES)}"
        )
    # Resolved on every call: preprocessing needs it even when the model is
    # already loaded.
    dataset_name = _DATASET_SHORT_NAMES[dataset]

    model_key = (variant, dataset_name, metric)
    if globals().get("loaded_model") is None or _loaded_model_key != model_key:
        # NOTE(review): the checkpoint folder is assumed to use the short
        # dataset name (the original code built ".../checkpoints/{dataset_name}/...").
        # Confirm against the layout of the Hub repositories.
        load_model(variant, dataset_name, metric)
        _loaded_model_key = model_key
        loaded_model.to(device)
    model = loaded_model

    # Preprocess the image.
    input_width, input_height = image.size
    image_tensor = transform(image, dataset_name).to(device)  # shape: (1, 3, H, W)

    # Upscale so that both edges are at least as long as the model's input size.
    input_size = model.input_size
    image_height, image_width = image_tensor.shape[-2:]
    aspect_ratio = image_width / image_height
    if image_height < input_size:
        new_height = input_size
        new_width = int(new_height * aspect_ratio)
        image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
        image_height, image_width = new_height, new_width
    if image_width < input_size:
        new_width = input_size
        new_height = int(new_width / aspect_ratio)
        image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
        image_height, image_width = new_height, new_width

    with torch.no_grad():
        num_vpt = getattr(model, "num_vpt", None)
        has_text_feats = (
            getattr(model, "pi_text_feats", None) is not None
            and getattr(model, "lambda_text_feats", None) is not None
        )
        if num_vpt is not None and num_vpt > 0:
            # For ViT models with VPT, use sliding window prediction.
            pi_map, lambda_map = _sliding_window_predict(
                model=model,
                image=image_tensor,
                window_size=input_size,
                stride=input_size
            )
        elif has_text_feats:
            # For other CLIP-based models.
            pi_map, lambda_map = _clip_maps(model, image_tensor)
        else:
            # For non-CLIP models.
            x = model.backbone(image_tensor)
            logit_pi_map = model.pi_head(x)  # shape: (B, 2, H, W)
            logit_map = model.bin_head(x)  # shape: (B, C, H, W)
            lambda_map = (logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # shape: (B, 1, H, W)
            pi_map = logit_pi_map.softmax(dim=1)[:, 0:1]  # shape: (B, 1, H, W)

        den_map = (1.0 - pi_map) * lambda_map  # shape: (B, 1, H, W)
        # Count on the native-resolution map, before any resizing.
        count = den_map.sum().item()

        structural_zero_map = _upsample_map(pi_map, input_height, input_width)
        lambda_map = _upsample_map(lambda_map, input_height, input_width)
        den_map = _upsample_map(den_map, input_height, input_width)

    sampling_zero_map = (1.0 - structural_zero_map) * np.exp(-lambda_map)
    complete_zero_map = structural_zero_map + sampling_zero_map

    # Ensure the original image is in RGBA format before blending.
    image_rgba = image.convert("RGBA")
    den_map = _overlay(image_rgba, den_map)
    structural_zero_map = _overlay(image_rgba, structural_zero_map)
    sampling_zero_map = _overlay(image_rgba, sampling_zero_map)
    lambda_map = _overlay(image_rgba, lambda_map)
    complete_zero_map = _overlay(image_rgba, complete_zero_map)

    return image, structural_zero_map, sampling_zero_map, complete_zero_map, lambda_map, den_map, f"Predicted Count: {count:.2f}"


# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by ZIP")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        with gr.Column():
            # Dropdown for model variant
            variant_dropdown = gr.Dropdown(
                choices=list(pretrained_datasets.keys()),
                value="ZIP-B",
                label="Select Model Variant"
            )

            # Dropdown for pretrained dataset, dynamically updated based on variant
            dataset_dropdown = gr.Dropdown(
                choices=pretrained_datasets["ZIP-B"],
                value=pretrained_datasets["ZIP-B"][0],
                label="Select Pretrained Dataset"
            )

            # Dropdown for metric, always the same choices
            metric_dropdown = gr.Dropdown(
                choices=["mae", "rmse", "nae"],
                value="mae",
                label="Select Best Metric"
            )

            # Update dataset choices when variant changes
            def update_dataset(variant):
                choices = pretrained_datasets[variant]
                # `gr.update` works across Gradio versions; the per-component
                # `gr.Dropdown.update` classmethod was removed in Gradio 4.
                return gr.update(
                    choices=choices,
                    value=choices[0]
                )

            variant_dropdown.change(
                fn=update_dataset,
                inputs=variant_dropdown,
                outputs=dataset_dropdown
            )

            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
            submit_btn = gr.Button("Predict")

        with gr.Column():
            output_den_map = gr.Image(label="Predicted Density Map", type="pil")
            output_structural_zero_map = gr.Image(label="Structural Zero Map", type="pil")
            output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
            output_lambda_map = gr.Image(label="Lambda Map", type="pil")
            output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    submit_btn.click(
        fn=predict,
        inputs=[input_img, variant_dropdown, dataset_dropdown, metric_dropdown],
        outputs=[input_img, output_structural_zero_map, output_sampling_zero_map, output_complete_zero_map, output_lambda_map, output_den_map, output_text]
    )

    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
            ["example3.jpg"],
            ["example4.jpg"],
            ["example5.jpg"],
            ["example6.jpg"],
            ["example7.jpg"],
            ["example8.jpg"],
            ["example9.jpg"],
            ["example10.jpg"],
            ["example11.jpg"],
            ["example12.jpg"]
        ],
        inputs=input_img,
        label="Try an example"
    )

demo.launch()