import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch import Tensor
import spaces
import numpy as np
from PIL import Image
import gradio as gr
from matplotlib import cm
from huggingface_hub import hf_hub_download
from warnings import warn
from models import get_model
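
# ImageNet normalization statistics, overlay blend strength, and a small epsilon for safe normalization.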
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
EPS = 1e-8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_datasets = {
    "ZIP-B": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF", "NWPU-Crowd"],
    "ZIP-S": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-T": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-N": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-P": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
}

# Model cache, lazily populated by load_model(); predict() reloads only when the selection changes.
loaded_model = None
loaded_config = None  # (variant, dataset_name, metric) of the cached model
# -----------------------------
# Define the model architecture
# -----------------------------
def load_model(variant: str, dataset: str = "shb", metric: str = "mae"):
    """Download the requested checkpoint from the Hugging Face Hub and build the model.

    Args:
        variant (str): Model variant, e.g. "ZIP-B".
        dataset (str): Short dataset name used in the checkpoint path ("sha", "shb", "qnrf" or "nwpu").
        metric (str): Metric whose best checkpoint is loaded ("mae", "rmse" or "nae").
    """
    global loaded_model, loaded_config
    # Build the model from the downloaded checkpoint and cache it.
    model_info_path = hf_hub_download(
        repo_id=f"Yiming-M/{variant}",
        filename=f"checkpoints/{dataset}/best_{metric}.pth",
    )
    model = get_model(model_info_path=model_info_path)
    model.eval()
    loaded_model = model
    loaded_config = (variant, dataset, metric)
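
# Illustrative usage (assuming the checkpoint layout above, e.g. checkpoints/shb/best_mae.pth, exists in the Hub repo):
#     load_model("ZIP-B", dataset="shb", metric="mae")  # caches the ShanghaiTech B model in `loaded_model`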
def _calc_size(
img_w: int,
img_h: int,
min_size: int,
max_size: int,
base: int = 32
):
"""
This function generates a new size for an image while keeping the aspect ratio. The new size should be within the given range (min_size, max_size).
Args:
img_w (int): The width of the image.
img_h (int): The height of the image.
min_size (int): The minimum size of the edges of the image.
max_size (int): The maximum size of the edges of the image.
# base (int): The base number to which the new size should be a multiple of.
"""
assert min_size % base == 0, f"min_size ({min_size}) must be a multiple of {base}"
if max_size != float("inf"):
assert max_size % base == 0, f"max_size ({max_size}) must be a multiple of {base} if provided"
assert min_size <= max_size, f"min_size ({min_size}) must be less than or equal to max_size ({max_size})"
aspect_ratios = (img_w / img_h, img_h / img_w)
if min_size / max_size <= min(aspect_ratios) <= max(aspect_ratios) <= max_size / min_size: # possible to resize and preserve the aspect ratio
if min_size <= min(img_w, img_h) <= max(img_w, img_h) <= max_size: # already within the range, no need to resize
ratio = 1.
elif min(img_w, img_h) < min_size: # smaller than the minimum size, resize to the minimum size
ratio = min_size / min(img_w, img_h)
else: # larger than the maximum size, resize to the maximum size
ratio = max_size / max(img_w, img_h)
new_w, new_h = int(round(img_w * ratio / base) * base), int(round(img_h * ratio / base) * base)
new_w = max(min_size, min(max_size, new_w))
new_h = max(min_size, min(max_size, new_h))
return new_w, new_h
else: # impossible to resize and preserve the aspect ratio
msg = f"Impossible to resize {img_w}x{img_h} image while preserving the aspect ratio to a size within the range ({min_size}, {max_size}). Will not limit the maximum size."
warn(msg)
return _calc_size(img_w, img_h, min_size, float("inf"), base)
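
# Illustrative values for _calc_size (computed from the logic above):
#     _calc_size(1024, 768, min_size=448, max_size=2048)  # -> (1024, 768): already within range and multiples of 32
#     _calc_size(400, 300, min_size=448, max_size=2048)   # -> (608, 448): upscaled so the shorter edge reaches min_size
# When the aspect ratio cannot be preserved within [min_size, max_size], the function warns and
# retries with max_size = inf, so extremely elongated images are only lower-bounded.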
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image, dataset_name: str) -> Tensor:
assert isinstance(image, Image.Image), "Input must be a PIL Image"
image_tensor = TF.to_tensor(image)
    if dataset_name in ("sha", "shb"):
        min_size = 448
        max_size = float("inf")
    elif dataset_name == "qnrf":
        min_size = 448
        max_size = 2048
    elif dataset_name == "nwpu":
        min_size = 448
        max_size = 3072
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")
image_height, image_width = image_tensor.shape[-2:]
new_width, new_height = _calc_size(
img_w=image_width,
img_h=image_height,
min_size=min_size,
max_size=max_size,
base=32
)
    if new_height != image_height or new_width != image_width:
        # Tensor inputs to TF.resize support only nearest/bilinear/bicubic interpolation, so bicubic is used instead of Lanczos.
        image_tensor = TF.resize(image_tensor, size=(new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
return image_tensor.unsqueeze(0) # Add batch dimension
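
# For example, a 1000x600 "qnrf" input becomes a normalized tensor of shape (1, 3, 608, 992):
# both edges are already within [448, 2048], so they are only snapped to multiples of 32.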
def _sliding_window_predict(
model: nn.Module,
image: Tensor,
window_size: int,
stride: int,
max_num_windows: int = 256
):
assert len(image.shape) == 4, f"Image must be a 4D tensor (1, c, h, w), got {image.shape}"
window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
window_size = tuple(window_size)
stride = tuple(stride)
assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, f"Window size must be a positive integer tuple (h, w), got {window_size}"
assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, f"Stride must be a positive integer tuple (h, w), got {stride}"
    assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"Stride must not exceed window size, got {stride} and {window_size}"
    image_height, image_width = image.shape[-2:]
    window_height, window_width = window_size
    assert image_height >= window_height and image_width >= window_width, f"Image size must be at least the window size, got image size {image.shape} and window size {window_size}"
stride_height, stride_width = stride
num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)
if hasattr(model, "block_size"):
block_size = model.block_size
elif hasattr(model, "module") and hasattr(model.module, "block_size"):
block_size = model.module.block_size
else:
raise ValueError("Model must have block_size attribute")
assert window_height % block_size == 0 and window_width % block_size == 0, f"Window size must be divisible by block size, got {window_size} and {block_size}"
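    # Tile the image with windows of `window_size`, stepping by `stride`. Windows that would run
    # past the border are shifted back to end exactly at the image edge, so border windows may
    # overlap their neighbours; overlapping predictions are averaged later via `count_map`.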
windows = []
for i in range(num_rows):
for j in range(num_cols):
x_start, y_start = i * stride_height, j * stride_width
x_end, y_end = x_start + window_height, y_start + window_width
if x_end > image_height:
x_start, x_end = image_height - window_height, image_height
if y_end > image_width:
y_start, y_end = image_width - window_width, image_width
window = image[:, :, x_start:x_end, y_start:y_end]
windows.append(window)
windows = torch.cat(windows, dim=0).to(image.device) # batched windows, shape: (num_windows, c, h, w)
model.eval()
pi_maps, lambda_maps = [], []
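    # Run the windows through the model in chunks of at most `max_num_windows` to bound memory use.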
for i in range(0, len(windows), max_num_windows):
with torch.no_grad():
image_feats = model.backbone(windows[i: min(i + max_num_windows, len(windows))])
pi_image_feats, lambda_image_feats = model.pi_head(image_feats), model.lambda_head(image_feats)
pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
pi_text_feats, lambda_text_feats = model.pi_text_feats, model.lambda_text_feats
pi_logit_scale, lambda_logit_scale = model.pi_logit_scale.exp(), model.lambda_logit_scale.exp()
pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t() # (B, H, W, 2), logits per image
lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t() # (B, H, W, N - 1), logits per image
pi_logit_map = pi_logit_map.permute(0, 3, 1, 2) # (B, 2, H, W)
lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2) # (B, N - 1, H, W)
lambda_map = (lambda_logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # (B, 1, H, W)
pi_map = pi_logit_map.softmax(dim=1)[:, 0:1] # (B, 1, H, W)
pi_maps.append(pi_map.cpu().numpy())
lambda_maps.append(lambda_map.cpu().numpy())
# assemble the density map
pi_maps = np.concatenate(pi_maps, axis=0) # shape: (num_windows, 1, H, W)
lambda_maps = np.concatenate(lambda_maps, axis=0) # shape: (num_windows, 1, H, W)
assert pi_maps.shape == lambda_maps.shape, f"pi_maps and lambda_maps must have the same shape, got {pi_maps.shape} and {lambda_maps.shape}"
pi_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
lambda_map = np.zeros((lambda_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
count_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
idx = 0
for i in range(num_rows):
for j in range(num_cols):
x_start, y_start = i * stride_height, j * stride_width
x_end, y_end = x_start + window_height, y_start + window_width
if x_end > image_height:
x_start, x_end = image_height - window_height, image_height
if y_end > image_width:
y_start, y_end = image_width - window_width, image_width
pi_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += pi_maps[idx, :, :, :]
lambda_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += lambda_maps[idx, :, :, :]
count_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += 1.
idx += 1
# average the density map
pi_map /= count_map
lambda_map /= count_map
# convert to Tensor and reshape
pi_map = torch.from_numpy(pi_map).unsqueeze(0) # shape: (1, 1, H // block_size, W // block_size)
lambda_map = torch.from_numpy(lambda_map).unsqueeze(0) # shape: (1, 1, H // block_size, W // block_size)
return pi_map, lambda_map
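
# Both maps returned above have spatial size (H // block_size, W // block_size): `pi_map` is the
# per-block probability of a structural zero and `lambda_map` the per-block Poisson rate, so the
# expected count under the zero-inflated Poisson model is ((1 - pi_map) * lambda_map).sum().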
# -----------------------------
# Inference function
# -----------------------------
@spaces.GPU(duration=120)
def predict(image: Image.Image, variant: str, dataset: str, metric: str):
"""
Given an input image, preprocess it, run the model to obtain a density map,
compute the total crowd count, and prepare the density map for display.
"""
    global loaded_model
    # Map the UI dataset label to the short name used in the checkpoint paths.
    if dataset == "ShanghaiTech A":
        dataset_name = "sha"
    elif dataset == "ShanghaiTech B":
        dataset_name = "shb"
    elif dataset == "UCF-QNRF":
        dataset_name = "qnrf"
    elif dataset == "NWPU-Crowd":
        dataset_name = "nwpu"
    else:
        raise gr.Error(f"Unsupported dataset: {dataset}")
    # (Re)load the weights if no model is cached yet or the requested configuration changed.
    if loaded_model is None or loaded_config != (variant, dataset_name, metric):
        load_model(variant=variant, dataset=dataset_name, metric=metric)
    loaded_model.to(device)
# Preprocess the image
input_width, input_height = image.size
image_tensor = transform(image, dataset_name).to(device) # shape: (1, 3, H, W)
input_size = loaded_model.input_size
image_height, image_width = image_tensor.shape[-2:]
aspect_ratio = image_width / image_height
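    # Ensure both edges are at least `input_size` (preserving the aspect ratio); otherwise the
    # sliding-window prediction below could not fit a single window inside the image.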
if image_height < input_size:
new_height = input_size
new_width = int(new_height * aspect_ratio)
image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
image_height, image_width = new_height, new_width
if image_width < input_size:
new_width = input_size
new_height = int(new_width / aspect_ratio)
image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
image_height, image_width = new_height, new_width
with torch.no_grad():
        if hasattr(loaded_model, "num_vpt") and loaded_model.num_vpt > 0:
            # ViT backbones with visual prompt tuning: predict with a sliding window over the image.
pi_map, lambda_map = _sliding_window_predict(
model=loaded_model,
image=image_tensor,
window_size=input_size,
stride=input_size
)
elif hasattr(loaded_model, "pi_text_feats") and hasattr(loaded_model, "lambda_text_feats") and loaded_model.pi_text_feats is not None and loaded_model.lambda_text_feats is not None: # For other CLIP-based models
image_feats = loaded_model.backbone(image_tensor)
# image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
pi_image_feats, lambda_image_feats = loaded_model.pi_head(image_feats), loaded_model.lambda_head(image_feats)
pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
pi_text_feats, lambda_text_feats = loaded_model.pi_text_feats, loaded_model.lambda_text_feats
pi_logit_scale, lambda_logit_scale = loaded_model.pi_logit_scale.exp(), loaded_model.lambda_logit_scale.exp()
pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t() # (B, H, W, 2), logits per image
lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t() # (B, H, W, N - 1), logits per image
pi_logit_map = pi_logit_map.permute(0, 3, 1, 2) # (B, 2, H, W)
lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2) # (B, N - 1, H, W)
lambda_map = (lambda_logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # (B, 1, H, W)
pi_map = pi_logit_map.softmax(dim=1)[:, 0:1] # (B, 1, H, W)
else: # For non-CLIP models
x = loaded_model.backbone(image_tensor)
logit_pi_map = loaded_model.pi_head(x) # shape: (B, 2, H, W)
logit_map = loaded_model.bin_head(x) # shape: (B, C, H, W)
lambda_map= (logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # shape: (B, 1, H, W)
pi_map = logit_pi_map.softmax(dim=1)[:, 0:1] # shape: (B, 1, H, W)
den_map = (1.0 - pi_map) * lambda_map # shape: (B, 1, H, W)
count = den_map.sum().item()
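    # Zero-inflated Poisson decomposition, upsampled to the input resolution for visualisation:
    #   structural zeros: probability pi that a block is structurally empty (e.g. background),
    #   sampling zeros:   probability (1 - pi) * exp(-lambda) that an active block contains no one,
    #   complete zeros:   their sum, i.e. the probability that a block contains zero people.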
    structural_zero_map = F.interpolate(
        pi_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()
    lambda_map = F.interpolate(
        lambda_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()
    den_map = F.interpolate(
        den_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()
    sampling_zero_map = (1.0 - structural_zero_map) * np.exp(-lambda_map)
    complete_zero_map = structural_zero_map + sampling_zero_map
# Normalize maps for display purposes
def normalize_map(x: np.ndarray) -> np.ndarray:
""" Normalize the map to [0, 1] range for visualization. """
x_min = np.min(x)
x_max = np.max(x)
if x_max - x_min < EPS:
return np.zeros_like(x)
return (x - x_min) / (x_max - x_min + EPS)
    structural_zero_map = normalize_map(structural_zero_map)
sampling_zero_map = normalize_map(sampling_zero_map)
lambda_map = normalize_map(lambda_map)
den_map = normalize_map(den_map)
complete_zero_map = normalize_map(complete_zero_map)
# Apply a colormap (e.g., 'jet') to get an RGBA image
colormap = cm.get_cmap("jet")
# The colormap returns values in [0,1]. Scale to [0,255] and convert to uint8.
den_map = (colormap(den_map) * 255).astype(np.uint8)
    structural_zero_map = (colormap(structural_zero_map) * 255).astype(np.uint8)
sampling_zero_map = (colormap(sampling_zero_map) * 255).astype(np.uint8)
lambda_map = (colormap(lambda_map) * 255).astype(np.uint8)
complete_zero_map = (colormap(complete_zero_map) * 255).astype(np.uint8)
# Convert to PIL images
den_map = Image.fromarray(den_map).convert("RGBA")
    structural_zero_map = Image.fromarray(structural_zero_map).convert("RGBA")
sampling_zero_map = Image.fromarray(sampling_zero_map).convert("RGBA")
lambda_map = Image.fromarray(lambda_map).convert("RGBA")
complete_zero_map = Image.fromarray(complete_zero_map).convert("RGBA")
# Ensure the original image is in RGBA format.
image_rgba = image.convert("RGBA")
den_map = Image.blend(image_rgba, den_map, alpha=alpha)
    structural_zero_map = Image.blend(image_rgba, structural_zero_map, alpha=alpha)
sampling_zero_map = Image.blend(image_rgba, sampling_zero_map, alpha=alpha)
lambda_map = Image.blend(image_rgba, lambda_map, alpha=alpha)
complete_zero_map = Image.blend(image_rgba, complete_zero_map, alpha=alpha)
    return image, structural_zero_map, sampling_zero_map, complete_zero_map, lambda_map, den_map, f"Predicted Count: {count:.2f}"
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
gr.Markdown("# Crowd Counting by ZIP")
gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
with gr.Row():
with gr.Column():
# Dropdown for model variant
variant_dropdown = gr.Dropdown(
choices=list(pretrained_datasets.keys()),
value="ZIP-B",
label="Select Model Variant"
)
# Dropdown for pretrained dataset, dynamically updated based on variant
dataset_dropdown = gr.Dropdown(
choices=pretrained_datasets["ZIP-B"],
value=pretrained_datasets["ZIP-B"][0],
label="Select Pretrained Dataset"
)
# Dropdown for metric, always the same choices
metric_dropdown = gr.Dropdown(
choices=["mae", "rmse", "nae"],
value="mae",
label="Select Best Metric"
)
# Update dataset choices when variant changes
            def update_dataset(variant):
                choices = pretrained_datasets[variant]
                # gr.update works across Gradio versions; Dropdown.update was removed in Gradio 4.
                return gr.update(choices=choices, value=choices[0])
variant_dropdown.change(
fn=update_dataset,
inputs=variant_dropdown,
outputs=dataset_dropdown
)
input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
submit_btn = gr.Button("Predict")
with gr.Column():
output_den_map = gr.Image(label="Predicted Density Map", type="pil")
output_structural_zero_map = gr.Image(label="Structural Zero Map", type="pil")
output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
output_lambda_map = gr.Image(label="Lambda Map", type="pil")
output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
output_text = gr.Textbox(label="Total Count")
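    # The order of `outputs` below must match the tuple returned by predict().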
submit_btn.click(
fn=predict,
inputs=[input_img, variant_dropdown, dataset_dropdown, metric_dropdown],
outputs=[input_img, output_structural_zero_map, output_sampling_zero_map, output_complete_zero_map, output_lambda_map, output_den_map, output_text]
)
gr.Examples(
examples=[
["example1.jpg"],
["example2.jpg"],
["example3.jpg"],
["example4.jpg"],
["example5.jpg"],
["example6.jpg"],
["example7.jpg"],
["example8.jpg"],
["example9.jpg"],
["example10.jpg"],
["example11.jpg"],
["example12.jpg"]
],
inputs=input_img,
label="Try an example"
)
demo.launch()