# test_ebc/custom/clip_ebc.py
# Author: piaspace, commit bb3e610 ("[first]")
import os
import sys
import torch
from torchvision.transforms.functional import normalize, to_pil_image
from torchvision.transforms import ToTensor, Normalize
import matplotlib.pyplot as plt
import json
from models import get_model
from utils import resize_density_map, sliding_window_predict
from PIL import Image
import numpy as np
from scipy.ndimage import gaussian_filter
from sklearn.cluster import KMeans
import datetime
from typing import Optional
from typing import Union
# Project root: the parent of this file's directory.
# NOTE(review): this is appended to sys.path only after `models` / `utils`
# have already been imported above, so it cannot help those imports resolve;
# they depend on the caller's sys.path.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:  # avoid duplicate entries on module reload
    sys.path.append(project_root)

# CUDA settings. setdefault keeps any value the launcher/environment already
# chose instead of silently overriding it. CUDA_LAUNCH_BLOCKING=1 makes kernel
# launches synchronous (handy for debugging, slower at runtime).
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Module-level default device (chosen after the environment is configured).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ClipEBC:
    """
    CLIP-EBC crowd-counting wrapper.

    Builds a CLIP-based counting model with ``get_model``, runs sliding-window
    inference on an image and keeps the most recent prediction (count, density
    map and de-normalized input image) for later inspection and visualization.

    Attributes:
        truncation (int): Selects the top-level key ``str(truncation)`` in the
            bin config file. Default 4.
        reduction (int): Selects the config file ``reduction_{reduction}.json``
            and is passed to ``get_model``. Default 8.
        granularity (str): Key used to pick bins / anchor points from the
            config. Default "fine".
        anchor_points_type (str): Constructor argument ``anchor_points``;
            "average" selects the "average" anchors, any other value selects
            the "middle" anchors. Default "average".
        anchor_points (list[float]): Anchor points loaded from the config.
        bins (list[tuple[float, float]]): Bins loaded from the config.
        model_name (str): Backbone name passed to ``get_model``.
            Default "clip_vit_b_16".
        input_size (int): Input size passed to ``get_model``. Default 224.
        window_size (int): Sliding-window size. Default 224.
        stride (int): Sliding-window stride. Default 224.
        prompt_type (str): Prompt type passed to ``get_model``. Default "word".
        dataset_name (str): Dataset key in the bin config. Default "qnrf".
        num_vpt (int): Number of visual prompt tokens. Default 32.
        vpt_drop (float): Visual prompt token dropout rate. Default 0.0.
        deep_vpt (bool): Whether to use deep visual prompt tokens. Default True.
        mean (tuple): Per-channel normalization mean. Default (0.485, 0.456, 0.406).
        std (tuple): Per-channel normalization std. Default (0.229, 0.224, 0.225).
        config_dir (str): Directory holding the bin config files. Default "configs".
        ckpt_path (str): State-dict checkpoint to load.
            Default "assets/CLIP_EBC_nwpu_rmse.pth".
    """

    # Raised when a visualization is requested before predict() has run
    # (Korean: "Run the predict method first to perform a prediction.").
    _NOT_PREDICTED_MSG = "๋จผ์ € predict ๋ฉ”์„œ๋“œ๋ฅผ ์‹คํ–‰ํ•˜์—ฌ ์˜ˆ์ธก์„ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค."

    # Results of the most recent predict() call. Declared at class level so the
    # accessors and visualizers see None (not AttributeError) before predict().
    count: Optional[float] = None
    density_map: Optional[np.ndarray] = None
    processed_image = None  # PIL.Image.Image once predict() has run
    original_image = None   # raw input passed to the most recent predict()

    def __init__(self,
                 truncation=4,
                 reduction=8,
                 granularity="fine",
                 anchor_points="average",
                 model_name="clip_vit_b_16",
                 input_size=224,
                 window_size=224,
                 stride=224,
                 prompt_type="word",
                 dataset_name="qnrf",
                 num_vpt=32,
                 vpt_drop=0.,
                 deep_vpt=True,
                 mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225),
                 config_dir="configs",
                 ckpt_path="assets/CLIP_EBC_nwpu_rmse.pth"):
        """Store the configuration, load the bin config and build the model."""
        self.truncation = truncation
        self.reduction = reduction
        self.granularity = granularity
        self.anchor_points_type = anchor_points  # raw selector; see _load_config
        self.model_name = model_name
        self.input_size = input_size
        self.window_size = window_size
        self.stride = stride
        self.prompt_type = prompt_type
        self.dataset_name = dataset_name
        self.num_vpt = num_vpt
        self.vpt_drop = vpt_drop
        self.deep_vpt = deep_vpt
        self.mean = mean
        self.std = std
        self.config_dir = config_dir
        self.ckpt_path = ckpt_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bins = None
        self.anchor_points = None
        self.model = None
        # Per-prediction state starts empty.
        self.count = None
        self.density_map = None
        self.processed_image = None
        self.original_image = None
        self._load_config()
        self._initialize_model()

    def _load_config(self):
        """Load bins and anchor points from ``{config_dir}/reduction_{reduction}.json``.

        The JSON is indexed as ``[str(truncation)][dataset_name]``. Bins become
        ``(float, float)`` tuples and anchor points become floats. An
        ``anchor_points_type`` of "average" selects the "average" anchors; any
        other value falls back to the "middle" anchors.
        """
        config_path = os.path.join(self.config_dir, f"reduction_{self.reduction}.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)[str(self.truncation)][self.dataset_name]
        self.bins = [(float(b[0]), float(b[1])) for b in config["bins"][self.granularity]]
        anchor_key = "average" if self.anchor_points_type == "average" else "middle"
        self.anchor_points = [float(p) for p in config["anchor_points"][self.granularity][anchor_key]]

    def _initialize_model(self):
        """Build the model with ``get_model``, load ``self.ckpt_path`` and switch to eval mode."""
        self.model = get_model(
            backbone=self.model_name,
            input_size=self.input_size,
            reduction=self.reduction,
            bins=self.bins,
            anchor_points=self.anchor_points,
            prompt_type=self.prompt_type,
            num_vpt=self.num_vpt,
            vpt_drop=self.vpt_drop,
            deep_vpt=self.deep_vpt
        )
        # Use this instance's device, the same one _process_image moves inputs to.
        state_dict = torch.load(self.ckpt_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()

    def _require_prediction(self):
        """Raise ValueError unless predict() has stored a density map and an image."""
        if self.density_map is None or self.processed_image is None:
            raise ValueError(self._NOT_PREDICTED_MSG)

    @staticmethod
    def _save_figure(fig, save_path: Optional[str], prefix: str) -> None:
        """Save *fig* without margins, auto-naming it ``{prefix}_{timestamp}.png`` when no path is given."""
        if save_path is None:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = f"{prefix}_{timestamp}.png"
        # Save this specific figure rather than whatever pyplot considers "current".
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
        print(f"Image saved to: {save_path}")

    @staticmethod
    def _figure_to_rgb(fig) -> np.ndarray:
        """Render *fig* and return its canvas as an (H, W, 3) uint8 RGB array.

        The result is a copy, so later redraws of the figure do not change it.
        """
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[:, :, :3].copy()

    def visualize_density_map(self, alpha: float = 0.5, save: bool = False,
                              save_path: Optional[str] = None):
        """
        Overlay the stored density map on the stored input image.

        Args:
            alpha (float): Opacity of the density map (0-1). Default 0.5.
            save (bool): Whether to also save the visualization to disk. Default False.
            save_path (str, optional): Output path. If None, an auto-generated
                ``crowd_density_<timestamp>.png`` in the current directory is used.

        Returns:
            Tuple[matplotlib.figure.Figure, np.ndarray]:
                - the Figure with the density map overlaid
                - the rendered visualization as an (H, W, 3) RGB array

        Raises:
            ValueError: If predict() has not been run yet.
        """
        self._require_prediction()
        fig, ax = plt.subplots(dpi=200, frameon=False)
        ax.imshow(self.processed_image)
        ax.imshow(self.density_map, cmap="jet", alpha=alpha)
        ax.axis("off")
        if save:
            self._save_figure(fig, save_path, "crowd_density")
        return fig, self._figure_to_rgb(fig)

    def _estimate_head_positions(self, sigma: float, percentile: float) -> np.ndarray:
        """Return an (N, 2) array of (row, col) points, at most ``round(self.count)`` of them.

        The density map is Gaussian-smoothed, pixels at or above the given
        percentile become candidates, and when there are more candidates than
        the rounded count they are reduced to that many KMeans cluster centers.
        """
        target_count = int(round(self.count))
        if target_count <= 0:
            # Nothing to draw; KMeans would also reject n_clusters <= 0.
            return np.empty((0, 2), dtype=int)
        filtered_density = gaussian_filter(self.density_map, sigma=sigma)
        threshold = np.percentile(filtered_density, percentile)
        candidate_pixels = np.column_stack(np.where(filtered_density >= threshold))
        if len(candidate_pixels) > target_count:
            kmeans = KMeans(n_clusters=target_count, random_state=42, n_init=10)
            kmeans.fit(candidate_pixels)
            return kmeans.cluster_centers_.astype(int)
        return candidate_pixels

    def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97,
                       save: bool = False, save_path: Optional[str] = None):
        """
        Mark estimated person locations with dots on the stored input image.

        Args:
            dot_size (int): Marker size. Default 20.
            sigma (float): Sigma of the Gaussian filter applied to the density map. Default 1.
            percentile (float): Percentile (0-100) of the smoothed density used
                as the candidate threshold. Default 97.
            save (bool): Whether to also save the visualization to disk. Default False.
            save_path (str, optional): Output path. If None, an auto-generated
                ``crowd_dots_<timestamp>.png`` in the current directory is used.

        Returns:
            Tuple[matplotlib canvas, np.ndarray]:
                - the figure's canvas object
                - the rendered visualization as an (H, W, 3) RGB array

        Raises:
            ValueError: If predict() has not been run yet.
        """
        self._require_prediction()
        head_positions = self._estimate_head_positions(sigma, percentile)
        fig, ax = plt.subplots(dpi=200, frameon=False)
        ax.imshow(self.processed_image)
        # Positions are (row, col); scatter expects (x, y) = (col, row).
        ax.scatter(head_positions[:, 1], head_positions[:, 0],
                   c='red',
                   s=dot_size,
                   alpha=1.0,
                   edgecolors='white',
                   linewidth=1)
        ax.axis("off")
        if save:
            self._save_figure(fig, save_path, "crowd_dots")
        return fig.canvas, self._figure_to_rgb(fig)

    @staticmethod
    def _to_pil_rgb(image) -> "Image.Image":
        """Convert a path, numpy array or file-like object into an RGB PIL image.

        Integer arrays are clipped to [0, 255]. Non-uint8, non-integer arrays
        are assumed to hold values in [0, 1], scaled by 255 and clipped so that
        out-of-range values saturate instead of wrapping around.

        Raises:
            ValueError: If a non-path, non-array input cannot be opened as an image.
        """
        if isinstance(image, str):
            with open(image, "rb") as f:
                return Image.open(f).convert("RGB")
        if isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                array = image
            elif np.issubdtype(image.dtype, np.integer):
                array = np.clip(image, 0, 255).astype(np.uint8)
            else:
                array = np.clip(image * 255, 0, 255).astype(np.uint8)
            # convert("RGB") also handles grayscale (H, W) and RGBA (H, W, 4) arrays.
            return Image.fromarray(array).convert("RGB")
        # Streamlit UploadedFile or any other file-like object.
        try:
            return Image.open(image).convert("RGB")
        except Exception as e:
            raise ValueError(f"์ง€์›ํ•˜์ง€ ์•Š๋Š” ์ด๋ฏธ์ง€ ํ˜•์‹์ž…๋‹ˆ๋‹ค: {type(image)}") from e

    def _process_image(self, image: Union[str, np.ndarray]) -> torch.Tensor:
        """
        Preprocess an image into a normalized batch tensor on ``self.device``.

        Args:
            image: One of
                - str: path to an image file
                - np.ndarray: (H, W, 3) RGB image (grayscale / RGBA are converted to RGB)
                - a file-like object such as a Streamlit UploadedFile

        Returns:
            torch.Tensor: Tensor of shape (1, 3, H, W), normalized with
            ``self.mean`` / ``self.std``.

        Raises:
            ValueError: If the input type is not supported.
            Exception: If an image file path cannot be opened.
        """
        to_tensor = ToTensor()
        # Same statistics that _post_process_image inverts.
        normalize_transform = Normalize(mean=list(self.mean), std=list(self.std))
        self.original_image = image
        pil_image = self._to_pil_rgb(image)
        normalized_image = normalize_transform(to_tensor(pil_image))
        return normalized_image.unsqueeze(0).to(self.device)  # (1, 3, H, W)

    def _post_process_image(self, image):
        """Undo the input normalization and return the first image of the batch as a PIL image."""
        image = normalize(image, mean=(0., 0., 0.),
                          std=(1. / self.std[0], 1. / self.std[1], 1. / self.std[2]))
        image = normalize(image, mean=(-self.mean[0], -self.mean[1], -self.mean[2]),
                          std=(1., 1., 1.))
        # Clamp float round-off so the byte conversion in to_pil_image cannot wrap.
        return to_pil_image(image.squeeze(0).clamp(0.0, 1.0))

    @torch.no_grad()
    def predict(self, image: Union[str, np.ndarray]) -> float:
        """
        Count the people in *image* and store the results on the instance.

        Args:
            image: Anything accepted by ``_process_image``: a file path, a numpy
                image array, or a file-like object such as a Streamlit UploadedFile.

        Returns:
            float: The predicted count (sum of the predicted density map).

        Side effects:
            Sets ``self.count``, ``self.density_map`` (the density resized to the
            input's height and width, as a numpy array), ``self.processed_image``
            (the de-normalized input as a PIL image) and ``self.original_image``.
        """
        batched_image = self._process_image(image)  # already on self.device
        image_height, image_width = batched_image.shape[-2:]
        pred_density = sliding_window_predict(self.model, batched_image,
                                              self.window_size, self.stride)
        pred_count = pred_density.sum().item()
        resized_pred_density = resize_density_map(pred_density,
                                                  (image_height, image_width)).cpu()
        self.processed_image = self._post_process_image(batched_image)
        self.density_map = resized_pred_density.squeeze().numpy()
        self.count = pred_count
        return pred_count

    def crowd_count(self):
        """
        Return the crowd count from the most recent prediction.

        Returns:
            float: The predicted count, or None if predict() has not run yet.
        """
        return self.count

    def get_density_map(self):
        """
        Return the density map from the most recent prediction.

        Returns:
            numpy.ndarray: The density map, or None if predict() has not run yet.
        """
        return self.density_map