# Hugging Face Space: "Crowd Counting by ZIP" (running on ZeroGPU)
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch import Tensor
import spaces
import numpy as np
from PIL import Image
import gradio as gr
from matplotlib import cm
from huggingface_hub import hf_hub_download
from warnings import warn

from models import get_model
# ImageNet normalization statistics and display settings
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
EPS = 1e-8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model = None

pretrained_models = [
    "ZIP-B @ ShanghaiTech A", "ZIP-B @ ShanghaiTech B", "ZIP-B @ UCF-QNRF", "ZIP-B @ NWPU-Crowd",
    "ZIP-S @ ShanghaiTech A", "ZIP-S @ ShanghaiTech B", "ZIP-S @ UCF-QNRF",
    "ZIP-T @ ShanghaiTech A", "ZIP-T @ ShanghaiTech B", "ZIP-T @ UCF-QNRF",
    "ZIP-N @ ShanghaiTech A", "ZIP-N @ ShanghaiTech B", "ZIP-N @ UCF-QNRF",
    "ZIP-P @ ShanghaiTech A", "ZIP-P @ ShanghaiTech B", "ZIP-P @ UCF-QNRF"
]
# -----------------------------
# Model loading
# -----------------------------
def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
    """Load the model weights from the Hugging Face Hub."""
    global loaded_model  # cache the model at module level so repeated calls reuse it
    # Download the checkpoint and build the model
    model_info_path = hf_hub_download(
        repo_id=f"Yiming-M/{variant}",
        filename=f"checkpoints/{dataset}/best_{metric}.pth",
    )
    model = get_model(model_info_path=model_info_path)
    model.eval()
    loaded_model = model
def _calc_size(
    img_w: int,
    img_h: int,
    min_size: int,
    max_size: int,
    base: int = 32
):
    """
    Compute a new image size that preserves the aspect ratio, keeps both edges within
    the range [min_size, max_size], and makes both edges multiples of `base`.

    Args:
        img_w (int): The width of the image.
        img_h (int): The height of the image.
        min_size (int): The minimum allowed edge length.
        max_size (int): The maximum allowed edge length.
        base (int): The number that both output dimensions must be a multiple of.
    """
    assert min_size % base == 0, f"min_size ({min_size}) must be a multiple of {base}"
    if max_size != float("inf"):
        assert max_size % base == 0, f"max_size ({max_size}) must be a multiple of {base} if provided"
    assert min_size <= max_size, f"min_size ({min_size}) must be less than or equal to max_size ({max_size})"

    aspect_ratios = (img_w / img_h, img_h / img_w)
    if min_size / max_size <= min(aspect_ratios) <= max(aspect_ratios) <= max_size / min_size:  # possible to resize while preserving the aspect ratio
        if min_size <= min(img_w, img_h) <= max(img_w, img_h) <= max_size:  # already within the range, no rescaling needed
            ratio = 1.
        elif min(img_w, img_h) < min_size:  # smaller than the minimum size, scale up to the minimum size
            ratio = min_size / min(img_w, img_h)
        else:  # larger than the maximum size, scale down to the maximum size
            ratio = max_size / max(img_w, img_h)
        new_w, new_h = int(round(img_w * ratio / base) * base), int(round(img_h * ratio / base) * base)
        new_w = max(min_size, min(max_size, new_w))
        new_h = max(min_size, min(max_size, new_h))
        return new_w, new_h
    else:  # impossible to resize while preserving the aspect ratio
        msg = f"Impossible to resize a {img_w}x{img_h} image to the range ({min_size}, {max_size}) while preserving the aspect ratio. The maximum size will not be limited."
        warn(msg)
        return _calc_size(img_w, img_h, min_size, float("inf"), base)
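
# Illustrative example (not executed): for a 1000x750 image with min_size=448,
# max_size=2048 and base=32, the image is already within range (ratio = 1), so the
# edges are only snapped to multiples of 32: _calc_size(1000, 750, 448, 2048) -> (992, 736).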
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image, dataset_name: str) -> Tensor:
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)

    if dataset_name in ("sha", "shb"):
        min_size = 448
        max_size = float("inf")
    elif dataset_name == "qnrf":
        min_size = 448
        max_size = 2048
    elif dataset_name == "nwpu":
        min_size = 448
        max_size = 3072
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    image_height, image_width = image_tensor.shape[-2:]
    new_width, new_height = _calc_size(
        img_w=image_width,
        img_h=image_height,
        min_size=min_size,
        max_size=max_size,
        base=32
    )
    if new_height != image_height or new_width != image_width:
        # Note: LANCZOS is not supported by torchvision for tensor inputs, so bicubic
        # resampling with antialiasing is used instead.
        image_tensor = TF.resize(image_tensor, size=(new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension
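
# Illustrative example (not executed): a 1024x768 RGB image with dataset_name="shb"
# keeps its size (both edges are already >= 448 and multiples of 32), so the returned
# tensor has shape (1, 3, 768, 1024) after normalization.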
def _sliding_window_predict(
    model: nn.Module,
    image: Tensor,
    window_size: int,
    stride: int,
    max_num_windows: int = 256
):
    assert len(image.shape) == 4, f"Image must be a 4D tensor (1, c, h, w), got {image.shape}"
    window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
    stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
    window_size = tuple(window_size)
    stride = tuple(stride)
    assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, f"Window size must be a positive integer tuple (h, w), got {window_size}"
    assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, f"Stride must be a positive integer tuple (h, w), got {stride}"
    assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"Stride must not be larger than the window size, got {stride} and {window_size}"

    image_height, image_width = image.shape[-2:]
    window_height, window_width = window_size
    assert image_height >= window_height and image_width >= window_width, f"Image size must be at least the window size, got image size {image.shape} and window size {window_size}"

    stride_height, stride_width = stride
    num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
    num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)
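
    # Illustrative example (not executed): for a 736x992 image with a 672x672 window and
    # a stride of 672, num_rows = num_cols = 2; windows that would overrun the bottom/right
    # border are shifted back so they end exactly at the image boundary, so the tiles
    # overlap there.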
    if hasattr(model, "block_size"):
        block_size = model.block_size
    elif hasattr(model, "module") and hasattr(model.module, "block_size"):
        block_size = model.module.block_size
    else:
        raise ValueError("Model must have a block_size attribute")
    assert window_height % block_size == 0 and window_width % block_size == 0, f"Window size must be divisible by the block size, got {window_size} and {block_size}"

    windows = []
    for i in range(num_rows):
        for j in range(num_cols):
            x_start, y_start = i * stride_height, j * stride_width
            x_end, y_end = x_start + window_height, y_start + window_width
            if x_end > image_height:
                x_start, x_end = image_height - window_height, image_height
            if y_end > image_width:
                y_start, y_end = image_width - window_width, image_width
            window = image[:, :, x_start:x_end, y_start:y_end]
            windows.append(window)
    windows = torch.cat(windows, dim=0).to(image.device)  # batched windows, shape: (num_windows, c, h, w)
    model.eval()
    pi_maps, lambda_maps = [], []
    for i in range(0, len(windows), max_num_windows):
        with torch.no_grad():
            image_feats = model.backbone(windows[i: min(i + max_num_windows, len(windows))])
            pi_image_feats, lambda_image_feats = model.pi_head(image_feats), model.lambda_head(image_feats)
            pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            pi_text_feats, lambda_text_feats = model.pi_text_feats, model.lambda_text_feats
            pi_logit_scale, lambda_logit_scale = model.pi_logit_scale.exp(), model.lambda_logit_scale.exp()
            pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t()  # (B, H, W, 2), logits per image
            lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t()  # (B, H, W, N - 1), logits per image
            pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
            lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)
            lambda_map = (lambda_logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            pi_map = pi_logit_map.softmax(dim=1)[:, 0:1]  # (B, 1, H, W)
        pi_maps.append(pi_map.cpu().numpy())
        lambda_maps.append(lambda_map.cpu().numpy())

    # Assemble the full-image maps from the per-window predictions
    pi_maps = np.concatenate(pi_maps, axis=0)  # shape: (num_windows, 1, H, W)
    lambda_maps = np.concatenate(lambda_maps, axis=0)  # shape: (num_windows, 1, H, W)
    assert pi_maps.shape == lambda_maps.shape, f"pi_maps and lambda_maps must have the same shape, got {pi_maps.shape} and {lambda_maps.shape}"

    pi_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
    lambda_map = np.zeros((lambda_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
    count_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)

    idx = 0
    for i in range(num_rows):
        for j in range(num_cols):
            x_start, y_start = i * stride_height, j * stride_width
            x_end, y_end = x_start + window_height, y_start + window_width
            if x_end > image_height:
                x_start, x_end = image_height - window_height, image_height
            if y_end > image_width:
                y_start, y_end = image_width - window_width, image_width
            pi_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += pi_maps[idx, :, :, :]
            lambda_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += lambda_maps[idx, :, :, :]
            count_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += 1.
            idx += 1

    # Average the predictions where windows overlap
    pi_map /= count_map
    lambda_map /= count_map

    # Convert to Tensor and add the batch dimension back
    pi_map = torch.from_numpy(pi_map).unsqueeze(0)  # shape: (1, 1, H // block_size, W // block_size)
    lambda_map = torch.from_numpy(lambda_map).unsqueeze(0)  # shape: (1, 1, H // block_size, W // block_size)
    return pi_map, lambda_map
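
# Illustrative usage (not executed): for a (1, 3, 736, 992) input with window_size=672,
# stride=672, and a backbone whose block_size is 16 (an assumed value), the returned
# pi_map and lambda_map would both have shape (1, 1, 46, 62).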
# -----------------------------
# Inference function
# -----------------------------
@spaces.GPU  # request a ZeroGPU device for the duration of each prediction
def predict(image: Image.Image, variant_dataset: str, metric: str):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    global loaded_model
    variant, dataset = variant_dataset.split(" @ ")
    if dataset == "ShanghaiTech A":
        dataset_name = "sha"
    elif dataset == "ShanghaiTech B":
        dataset_name = "shb"
    elif dataset == "UCF-QNRF":
        dataset_name = "qnrf"
    elif dataset == "NWPU-Crowd":
        dataset_name = "nwpu"
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # The model is loaded once and cached at module level
    if loaded_model is None:
        load_model(variant=variant, dataset=dataset_name, metric=metric)

    if not hasattr(loaded_model, "input_size"):
        if dataset_name == "sha":
            loaded_model.input_size = 224
        elif dataset_name == "shb":
            loaded_model.input_size = 448
        elif dataset_name in ("qnrf", "nwpu"):
            loaded_model.input_size = 672
    loaded_model.to(device)
    # Preprocess the image
    input_width, input_height = image.size
    image_tensor = transform(image, dataset_name).to(device)  # shape: (1, 3, H, W)

    input_size = loaded_model.input_size
    image_height, image_width = image_tensor.shape[-2:]
    aspect_ratio = image_width / image_height
    # Make sure both edges are at least `input_size` before running the model
    if image_height < input_size:
        new_height = input_size
        new_width = int(new_height * aspect_ratio)
        image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
        image_height, image_width = new_height, new_width
    if image_width < input_size:
        new_width = input_size
        new_height = int(new_width / aspect_ratio)
        image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
        image_height, image_width = new_height, new_width

    with torch.no_grad():
        if hasattr(loaded_model, "num_vpt") and loaded_model.num_vpt > 0:
            # For ViT models with VPT, use sliding-window prediction
            pi_map, lambda_map = _sliding_window_predict(
                model=loaded_model,
                image=image_tensor,
                window_size=input_size,
                stride=input_size
            )
        elif hasattr(loaded_model, "pi_text_feats") and hasattr(loaded_model, "lambda_text_feats") and loaded_model.pi_text_feats is not None and loaded_model.lambda_text_feats is not None:
            # For other CLIP-based models
            image_feats = loaded_model.backbone(image_tensor)
            # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            pi_image_feats, lambda_image_feats = loaded_model.pi_head(image_feats), loaded_model.lambda_head(image_feats)
            pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            pi_text_feats, lambda_text_feats = loaded_model.pi_text_feats, loaded_model.lambda_text_feats
            pi_logit_scale, lambda_logit_scale = loaded_model.pi_logit_scale.exp(), loaded_model.lambda_logit_scale.exp()
            pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t()  # (B, H, W, 2), logits per image
            lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t()  # (B, H, W, N - 1), logits per image
            pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
            lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)
            lambda_map = (lambda_logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            pi_map = pi_logit_map.softmax(dim=1)[:, 0:1]  # (B, 1, H, W)
        else:
            # For non-CLIP models
            x = loaded_model.backbone(image_tensor)
            logit_pi_map = loaded_model.pi_head(x)  # shape: (B, 2, H, W)
            logit_map = loaded_model.bin_head(x)  # shape: (B, C, H, W)
            lambda_map = (logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # shape: (B, 1, H, W)
            pi_map = logit_pi_map.softmax(dim=1)[:, 0:1]  # shape: (B, 1, H, W)

    den_map = (1.0 - pi_map) * lambda_map  # shape: (B, 1, H, W)
    count = den_map.sum().item()
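
    # The (1 - pi) * lambda product is the mean of a zero-inflated Poisson distribution with
    # structural-zero probability pi and rate lambda, so summing the density map over all
    # blocks yields the image-level count estimate.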
    structural_zero_map = F.interpolate(
        pi_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()
    lambda_map = F.interpolate(
        lambda_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()
    den_map = F.interpolate(
        den_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()
    sampling_zero_map = (1.0 - structural_zero_map) * np.exp(-lambda_map)
    complete_zero_map = structural_zero_map + sampling_zero_map
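
    # For a zero-inflated Poisson, P(count = 0) = pi + (1 - pi) * exp(-lambda): the first term
    # is the structural-zero probability and the second is the probability of a sampling zero,
    # which is exactly how complete_zero_map is assembled above.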
    # Normalize maps to [0, 1] for display purposes
    def normalize_map(x: np.ndarray) -> np.ndarray:
        """Normalize the map to the [0, 1] range for visualization."""
        x_min = np.min(x)
        x_max = np.max(x)
        if x_max - x_min < EPS:
            return np.zeros_like(x)
        return (x - x_min) / (x_max - x_min + EPS)
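
    # Illustrative example (not executed): normalize_map(np.array([0.0, 2.0, 4.0]))
    # returns approximately [0.0, 0.5, 1.0] (exact up to the EPS term in the denominator).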
    structural_zero_map = normalize_map(structural_zero_map)
    sampling_zero_map = normalize_map(sampling_zero_map)
    lambda_map = normalize_map(lambda_map)
    den_map = normalize_map(den_map)
    complete_zero_map = normalize_map(complete_zero_map)

    # Apply a colormap (e.g., 'jet') to get an RGBA image.
    colormap = cm.get_cmap("jet")
    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    den_map = (colormap(den_map) * 255).astype(np.uint8)
    structural_zero_map = (colormap(structural_zero_map) * 255).astype(np.uint8)
    sampling_zero_map = (colormap(sampling_zero_map) * 255).astype(np.uint8)
    lambda_map = (colormap(lambda_map) * 255).astype(np.uint8)
    complete_zero_map = (colormap(complete_zero_map) * 255).astype(np.uint8)

    # Convert to PIL images
    den_map = Image.fromarray(den_map).convert("RGBA")
    structural_zero_map = Image.fromarray(structural_zero_map).convert("RGBA")
    sampling_zero_map = Image.fromarray(sampling_zero_map).convert("RGBA")
    lambda_map = Image.fromarray(lambda_map).convert("RGBA")
    complete_zero_map = Image.fromarray(complete_zero_map).convert("RGBA")

    # Ensure the original image is in RGBA format, then blend the maps over it.
    image_rgba = image.convert("RGBA")
    den_map = Image.blend(image_rgba, den_map, alpha=alpha)
    structural_zero_map = Image.blend(image_rgba, structural_zero_map, alpha=alpha)
    sampling_zero_map = Image.blend(image_rgba, sampling_zero_map, alpha=alpha)
    lambda_map = Image.blend(image_rgba, lambda_map, alpha=alpha)
    complete_zero_map = Image.blend(image_rgba, complete_zero_map, alpha=alpha)

    return image, den_map, lambda_map, round(count, 2), structural_zero_map, sampling_zero_map, complete_zero_map
# -----------------------------
# Build the Gradio interface using Blocks for a multi-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by ZIP")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
    with gr.Row():
        with gr.Column():
            # Dropdown for the model variant and training dataset
            model_dropdown = gr.Dropdown(
                choices=pretrained_models,
                value="ZIP-B @ NWPU-Crowd",
                label="Select a pretrained model"
            )
            # Dropdown for the checkpoint-selection metric; the choices are always the same
            metric_dropdown = gr.Dropdown(
                choices=["mae", "rmse", "nae"],
                value="mae",
                label="Select Best Metric"
            )
            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_den_map = gr.Image(label="Predicted Density Map", type="pil")
            output_lambda_map = gr.Image(label="Lambda Map", type="pil")
            output_text = gr.Textbox(label="Predicted Count")
        with gr.Column():
            output_structural_zero_map = gr.Image(label="Structural Zero Map", type="pil")
            output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
            output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")

    submit_btn.click(
        fn=predict,
        inputs=[input_img, model_dropdown, metric_dropdown],
        outputs=[input_img, output_den_map, output_lambda_map, output_text, output_structural_zero_map, output_sampling_zero_map, output_complete_zero_map]
    )
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
            ["example3.jpg"],
            ["example4.jpg"],
            ["example5.jpg"],
            ["example6.jpg"],
            ["example7.jpg"],
            ["example8.jpg"],
            ["example9.jpg"],
            ["example10.jpg"],
            ["example11.jpg"],
            ["example12.jpg"]
        ],
        inputs=input_img,
        label="Try an example"
    )

demo.launch()