# Hugging Face Space app. Space status at capture time: "Running on Zero" (ZeroGPU hardware).
from warnings import warn

import gradio as gr
import matplotlib
import numpy as np
import spaces
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from huggingface_hub import hf_hub_download
from matplotlib import cm
from PIL import Image
from torch import nn
from torch import Tensor

from models import get_model
# Per-channel mean/std used by `transform` to normalise input images.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# Weight of the colour-mapped overlay when blended over the input image
# (Image.blend: 0 -> input image only, 1 -> overlay only).
alpha = 0.8
# Small constant guarding divisions by (near-)zero value ranges.
EPS = 1e-8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained datasets offered in the UI for each model variant.
# NOTE: every entry needs its own comma; adjacent string literals are silently
# concatenated by Python ("UCF-QNRF" "NWPU-Crowd" -> "UCF-QNRFNWPU-Crowd").
pretrained_datasets = {
    "ZIP-B": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF", "NWPU-Crowd"],
    "ZIP-S": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-T": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-N": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-P": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
}
# -----------------------------
# Model loading
# -----------------------------
# Model currently used for inference; set by `load_model`. Initialised here so
# code that checks `loaded_model is None` does not hit a NameError.
loaded_model = None


def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
    """Download a checkpoint from the Hugging Face Hub and make it the active model.

    The file ``checkpoints/<dataset>/best_<metric>.pth`` is fetched from the Hub
    repo ``Yiming-M/<variant>`` and turned into a model by ``get_model``. The model
    is put in eval mode, stored in the module-level ``loaded_model`` and returned.

    Args:
        variant: Model variant, used as the repo name (e.g. "ZIP-B").
        dataset: Sub-folder of ``checkpoints/`` to load from.
            NOTE(review): the default "ShanghaiTech B" is a UI display name,
            while the inference path in this file builds checkpoint paths from
            short codes such as "shb". Confirm which naming the Hub repos use.
        metric: Metric the checkpoint was selected by (e.g. "mae").

    Returns:
        The loaded model (same object as the new ``loaded_model``).
    """
    global loaded_model
    model_info_path = hf_hub_download(
        repo_id=f"Yiming-M/{variant}",
        filename=f"checkpoints/{dataset}/best_{metric}.pth",
    )
    model = get_model(model_info_path=model_info_path)
    model.eval()
    loaded_model = model
    return model
def _calc_size( | |
img_w: int, | |
img_h: int, | |
min_size: int, | |
max_size: int, | |
base: int = 32 | |
): | |
""" | |
This function generates a new size for an image while keeping the aspect ratio. The new size should be within the given range (min_size, max_size). | |
Args: | |
img_w (int): The width of the image. | |
img_h (int): The height of the image. | |
min_size (int): The minimum size of the edges of the image. | |
max_size (int): The maximum size of the edges of the image. | |
# base (int): The base number to which the new size should be a multiple of. | |
""" | |
assert min_size % base == 0, f"min_size ({min_size}) must be a multiple of {base}" | |
if max_size != float("inf"): | |
assert max_size % base == 0, f"max_size ({max_size}) must be a multiple of {base} if provided" | |
assert min_size <= max_size, f"min_size ({min_size}) must be less than or equal to max_size ({max_size})" | |
aspect_ratios = (img_w / img_h, img_h / img_w) | |
if min_size / max_size <= min(aspect_ratios) <= max(aspect_ratios) <= max_size / min_size: # possible to resize and preserve the aspect ratio | |
if min_size <= min(img_w, img_h) <= max(img_w, img_h) <= max_size: # already within the range, no need to resize | |
ratio = 1. | |
elif min(img_w, img_h) < min_size: # smaller than the minimum size, resize to the minimum size | |
ratio = min_size / min(img_w, img_h) | |
else: # larger than the maximum size, resize to the maximum size | |
ratio = max_size / max(img_w, img_h) | |
new_w, new_h = int(round(img_w * ratio / base) * base), int(round(img_h * ratio / base) * base) | |
new_w = max(min_size, min(max_size, new_w)) | |
new_h = max(min_size, min(max_size, new_h)) | |
return new_w, new_h | |
else: # impossible to resize and preserve the aspect ratio | |
msg = f"Impossible to resize {img_w}x{img_h} image while preserving the aspect ratio to a size within the range ({min_size}, {max_size}). Will not limit the maximum size." | |
warn(msg) | |
return _calc_size(img_w, img_h, min_size, float("inf"), base) | |
# -----------------------------
# Preprocessing function
# -----------------------------
# Allowed (min_size, max_size) edge range per training dataset code. All finite
# bounds are multiples of 32, as `_calc_size(..., base=32)` requires.
_RESIZE_RANGES = {
    "sha": (448, float("inf")),
    "shb": (448, float("inf")),
    "qnrf": (448, 2048),
    "nwpu": (448, 3072),
}


def transform(image: Image.Image, dataset_name: str) -> Tensor:
    """Turn a PIL image into a normalised, batched tensor for the model.

    The image is converted to RGB and resized with `_calc_size`, which keeps the
    aspect ratio and snaps each edge to a multiple of 32 within the range
    configured for ``dataset_name``. It is then normalised with the module-level
    ``mean``/``std``.

    Args:
        image: Input PIL image; any mode is accepted and converted to RGB.
        dataset_name: One of "sha", "shb", "qnrf", "nwpu".

    Returns:
        Float tensor of shape (1, 3, H, W).

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset code.
    """
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    try:
        min_size, max_size = _RESIZE_RANGES[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(_RESIZE_RANGES)}"
        ) from None

    # The normalisation statistics have 3 channels, so RGBA / L / P inputs must become RGB.
    image = image.convert("RGB")
    image_width, image_height = image.size
    new_width, new_height = _calc_size(
        img_w=image_width,
        img_h=image_height,
        min_size=min_size,
        max_size=max_size,
        base=32,
    )
    if (new_width, new_height) != (image_width, image_height):
        # Resize the PIL image rather than the tensor: torchvision supports
        # LANCZOS only for PIL inputs (tensor resize allows nearest/bilinear/bicubic).
        image = TF.resize(image, size=[new_height, new_width], interpolation=TF.InterpolationMode.LANCZOS)

    image_tensor = TF.normalize(TF.to_tensor(image), mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # add batch dimension
def _sliding_window_predict( | |
model: nn.Module, | |
image: Tensor, | |
window_size: int, | |
stride: int, | |
max_num_windows: int = 256 | |
): | |
assert len(image.shape) == 4, f"Image must be a 4D tensor (1, c, h, w), got {image.shape}" | |
window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size | |
stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride | |
window_size = tuple(window_size) | |
stride = tuple(stride) | |
assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, f"Window size must be a positive integer tuple (h, w), got {window_size}" | |
assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, f"Stride must be a positive integer tuple (h, w), got {stride}" | |
assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"Stride must be smaller than window size, got {stride} and {window_size}" | |
image_height, image_width = image.shape[-2:] | |
window_height, window_width = window_size | |
assert image_height >= window_height and image_width >= window_width, f"Image size must be larger than window size, got image size {image.shape} and window size {window_size}" | |
stride_height, stride_width = stride | |
num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1) | |
num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1) | |
if hasattr(model, "block_size"): | |
block_size = model.block_size | |
elif hasattr(model, "module") and hasattr(model.module, "block_size"): | |
block_size = model.module.block_size | |
else: | |
raise ValueError("Model must have block_size attribute") | |
assert window_height % block_size == 0 and window_width % block_size == 0, f"Window size must be divisible by block size, got {window_size} and {block_size}" | |
windows = [] | |
for i in range(num_rows): | |
for j in range(num_cols): | |
x_start, y_start = i * stride_height, j * stride_width | |
x_end, y_end = x_start + window_height, y_start + window_width | |
if x_end > image_height: | |
x_start, x_end = image_height - window_height, image_height | |
if y_end > image_width: | |
y_start, y_end = image_width - window_width, image_width | |
window = image[:, :, x_start:x_end, y_start:y_end] | |
windows.append(window) | |
windows = torch.cat(windows, dim=0).to(image.device) # batched windows, shape: (num_windows, c, h, w) | |
model.eval() | |
pi_maps, lambda_maps = [], [] | |
for i in range(0, len(windows), max_num_windows): | |
with torch.no_grad(): | |
image_feats = model.backbone(windows[i: min(i + max_num_windows, len(windows))]) | |
pi_image_feats, lambda_image_feats = model.pi_head(image_feats), model.lambda_head(image_feats) | |
pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C) | |
lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C) | |
pi_text_feats, lambda_text_feats = model.pi_text_feats, model.lambda_text_feats | |
pi_logit_scale, lambda_logit_scale = model.pi_logit_scale.exp(), model.lambda_logit_scale.exp() | |
pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t() # (B, H, W, 2), logits per image | |
lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t() # (B, H, W, N - 1), logits per image | |
pi_logit_map = pi_logit_map.permute(0, 3, 1, 2) # (B, 2, H, W) | |
lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2) # (B, N - 1, H, W) | |
lambda_map = (lambda_logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # (B, 1, H, W) | |
pi_map = pi_logit_map.softmax(dim=1)[:, 0:1] # (B, 1, H, W) | |
pi_maps.append(pi_map.cpu().numpy()) | |
lambda_maps.append(lambda_map.cpu().numpy()) | |
# assemble the density map | |
pi_maps = np.concatenate(pi_maps, axis=0) # shape: (num_windows, 1, H, W) | |
lambda_maps = np.concatenate(lambda_maps, axis=0) # shape: (num_windows, 1, H, W) | |
assert pi_maps.shape == lambda_maps.shape, f"pi_maps and lambda_maps must have the same shape, got {pi_maps.shape} and {lambda_maps.shape}" | |
pi_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32) | |
lambda_map = np.zeros((lambda_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32) | |
count_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32) | |
idx = 0 | |
for i in range(num_rows): | |
for j in range(num_cols): | |
x_start, y_start = i * stride_height, j * stride_width | |
x_end, y_end = x_start + window_height, y_start + window_width | |
if x_end > image_height: | |
x_start, x_end = image_height - window_height, image_height | |
if y_end > image_width: | |
y_start, y_end = image_width - window_width, image_width | |
pi_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += pi_maps[idx, :, :, :] | |
lambda_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += lambda_maps[idx, :, :, :] | |
count_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += 1. | |
idx += 1 | |
# average the density map | |
pi_map /= count_map | |
lambda_map /= count_map | |
# convert to Tensor and reshape | |
pi_map = torch.from_numpy(pi_map).unsqueeze(0) # shape: (1, 1, H // block_size, W // block_size) | |
lambda_map = torch.from_numpy(lambda_map).unsqueeze(0) # shape: (1, 1, H // block_size, W // block_size) | |
return pi_map, lambda_map | |
# -----------------------------
# Inference function
# -----------------------------
# UI dataset names -> short dataset codes. The codes select the resize range in
# `transform` and the checkpoint folder (checkpoints/<code>/best_<metric>.pth).
_DATASET_CODES = {
    "ShanghaiTech A": "sha",
    "ShanghaiTech B": "shb",
    "UCF-QNRF": "qnrf",
    "NWPU-Crowd": "nwpu",
}

# (variant, dataset_code, metric) of the checkpoint currently held in
# `loaded_model`; None until `predict` has loaded one.
_loaded_config = None


def _normalize_map(x: np.ndarray) -> np.ndarray:
    """Min-max scale ``x`` to [0, 1] for display; a (near-)constant map becomes all zeros."""
    x_min, x_max = np.min(x), np.max(x)
    if x_max - x_min < EPS:
        return np.zeros_like(x)
    return (x - x_min) / (x_max - x_min + EPS)


def _resize_map(value_map: Tensor, height: int, width: int) -> np.ndarray:
    """Bilinearly resize a (1, 1, h, w) map to (height, width) and return it as a 2-D array."""
    resized = F.interpolate(value_map, size=(height, width), mode="bilinear", align_corners=False, antialias=True)
    return resized.cpu().squeeze().numpy()


def _overlay(base_rgba: Image.Image, values: np.ndarray, colormap) -> Image.Image:
    """Normalise ``values``, colour them with ``colormap`` and alpha-blend over ``base_rgba``."""
    colored = (colormap(_normalize_map(values)) * 255).astype(np.uint8)  # (H, W, 4) RGBA
    return Image.blend(base_rgba, Image.fromarray(colored).convert("RGBA"), alpha=alpha)


def _predict_maps(model: nn.Module, image_tensor: Tensor, input_size: int):
    """Run ``model`` on ``image_tensor`` and return (pi_map, lambda_map), each (B, 1, h, w).

    ``pi_map`` is the class-0 probability of the pi head. ``lambda_map`` is the
    softmax-weighted average of ``model.bin_centers[:, 1:]``.
    """
    if hasattr(model, "num_vpt") and model.num_vpt > 0:
        # ViT models with VPT tokens are evaluated window by window.
        return _sliding_window_predict(model=model, image=image_tensor, window_size=input_size, stride=input_size)

    if (
        hasattr(model, "pi_text_feats") and hasattr(model, "lambda_text_feats")
        and model.pi_text_feats is not None and model.lambda_text_feats is not None
    ):
        # CLIP-style models: scaled similarity between image features and text features.
        image_feats = model.backbone(image_tensor)
        pi_image_feats, lambda_image_feats = model.pi_head(image_feats), model.lambda_head(image_feats)
        pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # (B, H, W, C)
        lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # (B, H, W, C)
        pi_logit_scale, lambda_logit_scale = model.pi_logit_scale.exp(), model.lambda_logit_scale.exp()
        pi_logit_map = (pi_logit_scale * pi_image_feats @ model.pi_text_feats.t()).permute(0, 3, 1, 2)  # (B, 2, H, W)
        lambda_logit_map = (lambda_logit_scale * lambda_image_feats @ model.lambda_text_feats.t()).permute(0, 3, 1, 2)  # (B, N - 1, H, W)
        lambda_map = (lambda_logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)
        pi_map = pi_logit_map.softmax(dim=1)[:, 0:1]
        return pi_map, lambda_map

    # Non-CLIP models: plain classification heads.
    feats = model.backbone(image_tensor)
    logit_pi_map = model.pi_head(feats)  # (B, 2, H, W)
    logit_map = model.bin_head(feats)  # (B, C, H, W)
    lambda_map = (logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)
    pi_map = logit_pi_map.softmax(dim=1)[:, 0:1]
    return pi_map, lambda_map


# On ZeroGPU Spaces this decorator requests a GPU for the duration of the call;
# Hugging Face documents it as having no effect outside ZeroGPU.
@spaces.GPU
def predict(image: Image.Image, variant: str, dataset: str, metric: str):
    """
    Count people in ``image`` and build the visualisations shown in the UI.

    The checkpoint matching (variant, dataset, metric) is loaded on the first call
    and again whenever the selection changes.

    Args:
        image: Input PIL image from the Gradio component (None if nothing was provided).
        variant: Model variant, e.g. "ZIP-B".
        dataset: UI dataset name, a key of ``_DATASET_CODES``.
        metric: Checkpoint-selection metric, e.g. "mae".

    Returns:
        Tuple ``(image, structural_zero, sampling_zero, complete_zero, lambda,
        density, "Predicted Count: <n>")``. The five maps are colour-mapped
        overlays blended over the input image.

    Raises:
        gr.Error: If no image was given or the dataset name is unknown.
    """
    global _loaded_config

    if image is None:
        raise gr.Error("Please upload an image or pick an example first.")
    dataset_name = _DATASET_CODES.get(dataset)
    if dataset_name is None:
        raise gr.Error(f"Unknown pretrained dataset: {dataset!r}")

    # Load (or reload) when the requested checkpoint differs from the one in memory.
    config = (variant, dataset_name, metric)
    if _loaded_config != config:
        load_model(variant=variant, dataset=dataset_name, metric=metric)
        loaded_model.to(device)
        _loaded_config = config
    model = loaded_model

    input_width, input_height = image.size
    image_tensor = transform(image.convert("RGB"), dataset_name).to(device)  # (1, 3, H, W)

    # Upscale (keeping the aspect ratio) so neither side is below the model's input size.
    input_size = model.input_size
    image_height, image_width = image_tensor.shape[-2:]
    aspect_ratio = image_width / image_height
    if image_height < input_size:
        image_height, image_width = input_size, int(input_size * aspect_ratio)
        image_tensor = F.interpolate(image_tensor, size=(image_height, image_width), mode="bicubic", align_corners=False, antialias=True)
    if image_width < input_size:
        image_height, image_width = int(input_size / aspect_ratio), input_size
        image_tensor = F.interpolate(image_tensor, size=(image_height, image_width), mode="bicubic", align_corners=False, antialias=True)

    with torch.no_grad():
        pi_map, lambda_map = _predict_maps(model, image_tensor, input_size)
        den_map = (1.0 - pi_map) * lambda_map  # (B, 1, h, w)
        count = den_map.sum().item()

        structural_zero_map = _resize_map(pi_map, input_height, input_width)
        lambda_np = _resize_map(lambda_map, input_height, input_width)
        density_np = _resize_map(den_map, input_height, input_width)

    # complete = structural + (1 - structural) * exp(-lambda)
    sampling_zero_map = (1.0 - structural_zero_map) * np.exp(-lambda_np)
    complete_zero_map = structural_zero_map + sampling_zero_map

    # `cm.get_cmap` was removed in matplotlib 3.9; the colormap registry replaces it.
    colormap = matplotlib.colormaps["jet"]
    image_rgba = image.convert("RGBA")
    return (
        image,
        _overlay(image_rgba, structural_zero_map, colormap),
        _overlay(image_rgba, sampling_zero_map, colormap),
        _overlay(image_rgba, complete_zero_map, colormap),
        _overlay(image_rgba, lambda_np, colormap),
        _overlay(image_rgba, density_np, colormap),
        f"Predicted Count: {count:.2f}",
    )
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
def update_dataset(variant):
    """Refresh the dataset dropdown with the pretrained datasets of ``variant``."""
    choices = pretrained_datasets[variant]
    # Gradio 4 removed `gr.Dropdown.update`; returning a component instance
    # updates the bound output's properties instead.
    return gr.Dropdown(choices=choices, value=choices[0])


with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by ZIP")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        with gr.Column():
            variant_dropdown = gr.Dropdown(
                choices=list(pretrained_datasets.keys()),
                value="ZIP-B",
                label="Select Model Variant",
            )
            # Choices are refreshed by `update_dataset` when the variant changes.
            dataset_dropdown = gr.Dropdown(
                choices=pretrained_datasets["ZIP-B"],
                value=pretrained_datasets["ZIP-B"][0],
                label="Select Pretrained Dataset",
            )
            # The same metric choices apply to every variant/dataset.
            metric_dropdown = gr.Dropdown(
                choices=["mae", "rmse", "nae"],
                value="mae",
                label="Select Best Metric",
            )
            variant_dropdown.change(
                fn=update_dataset,
                inputs=variant_dropdown,
                outputs=dataset_dropdown,
            )
            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
            submit_btn = gr.Button("Predict")

        with gr.Column():
            output_den_map = gr.Image(label="Predicted Density Map", type="pil")
            output_structural_zero_map = gr.Image(label="Structural Zero Map", type="pil")
            output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
            output_lambda_map = gr.Image(label="Lambda Map", type="pil")
            output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    # The order of `outputs` must match the tuple returned by `predict`.
    submit_btn.click(
        fn=predict,
        inputs=[input_img, variant_dropdown, dataset_dropdown, metric_dropdown],
        outputs=[
            input_img,
            output_structural_zero_map,
            output_sampling_zero_map,
            output_complete_zero_map,
            output_lambda_map,
            output_den_map,
            output_text,
        ],
    )

    gr.Examples(
        examples=[[f"example{i}.jpg"] for i in range(1, 13)],
        inputs=input_img,
        label="Try an example",
    )

demo.launch()