import os
import sys

# Make the project root importable before pulling in the local modules below.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

import datetime
import json
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter
from sklearn.cluster import KMeans
from torchvision.transforms import Normalize, ToTensor
from torchvision.transforms.functional import normalize, to_pil_image

from models import get_model
from utils import resize_density_map, sliding_window_predict

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ClipEBC:
    """
    Image-processing class for CLIP-EBC (Enhanced Blockwise Classification).

    Processes images with a CLIP-based model and provides a range of
    configuration options, including sliding-window prediction.

    Attributes:
        truncation (int): Truncation parameter. Defaults to 4.
        reduction (int): Reduction ratio. Defaults to 8.
        granularity (str): Granularity level. Defaults to "fine".
        anchor_points (str): Anchor-point method. Defaults to "average".
        model_name (str): CLIP model name. Defaults to "clip_vit_b_16".
        input_size (int): Input image size. Defaults to 224.
        window_size (int): Sliding-window size. Defaults to 224.
        stride (int): Sliding-window stride. Defaults to 224.
        prompt_type (str): Prompt type. Defaults to "word".
        dataset_name (str): Dataset name. Defaults to "qnrf".
        num_vpt (int): Number of visual prompt tokens. Defaults to 32.
        vpt_drop (float): Dropout rate for visual prompt tokens. Defaults to 0.0.
        deep_vpt (bool): Whether to use deep visual prompt tokens. Defaults to True.
        mean (tuple): Mean values for normalization. Defaults to (0.485, 0.456, 0.406).
        std (tuple): Standard-deviation values for normalization. Defaults to (0.229, 0.224, 0.225).
    """
    def __init__(self,
                 truncation=4,
                 reduction=8,
                 granularity="fine",
                 anchor_points="average",
                 model_name="clip_vit_b_16",
                 input_size=224,
                 window_size=224,
                 stride=224,
                 prompt_type="word",
                 dataset_name="qnrf",
                 num_vpt=32,
                 vpt_drop=0.,
                 deep_vpt=True,
                 mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225),
                 config_dir="configs"):
        """Initializes the ClipEBC class with its configuration parameters."""
        self.truncation = truncation
        self.reduction = reduction
        self.granularity = granularity
        self.anchor_points_type = anchor_points  # keep the raw input value
        self.model_name = model_name
        self.input_size = input_size
        self.window_size = window_size
        self.stride = stride
        self.prompt_type = prompt_type
        self.dataset_name = dataset_name
        self.num_vpt = num_vpt
        self.vpt_drop = vpt_drop
        self.deep_vpt = deep_vpt
        self.mean = mean
        self.std = std
        self.config_dir = config_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bins = None
        self.anchor_points = None
        self.model = None

        # Prediction state; populated by predict(). Initialized here so the
        # visualization methods can raise a clear error instead of AttributeError.
        self.original_image = None
        self.processed_image = None
        self.density_map = None
        self.count = None

        # Load the initial config and build the model.
        self._load_config()
        self._initialize_model()
    def _load_config(self):
        """Loads the config file and sets bins and anchor_points."""
        config_path = os.path.join(self.config_dir, f"reduction_{self.reduction}.json")
        with open(config_path, "r") as f:
            config = json.load(f)[str(self.truncation)][self.dataset_name]
        self.bins = config["bins"][self.granularity]
        self.bins = [(float(b[0]), float(b[1])) for b in self.bins]
        if self.anchor_points_type == "average":
            self.anchor_points = config["anchor_points"][self.granularity]["average"]
        else:
            self.anchor_points = config["anchor_points"][self.granularity]["middle"]
        self.anchor_points = [float(p) for p in self.anchor_points]
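
    # A minimal sketch of the config layout _load_config expects, inferred from
    # the lookups above. The nesting follows how the method indexes the JSON;
    # the concrete values are illustrative assumptions, not the shipped contents:
    #
    # configs/reduction_8.json
    # {
    #     "4": {                              # truncation
    #         "qnrf": {                       # dataset_name
    #             "bins": {
    #                 "fine": [[0, 0], [1, 1], [2, 2], ...]
    #             },
    #             "anchor_points": {
    #                 "fine": {
    #                     "average": [0.0, 1.0, ...],
    #                     "middle": [0.0, 1.0, ...]
    #                 }
    #             }
    #         }
    #     }
    # }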
    def _initialize_model(self):
        """Initializes the CLIP model and loads the pretrained checkpoint."""
        self.model = get_model(
            backbone=self.model_name,
            input_size=self.input_size,
            reduction=self.reduction,
            bins=self.bins,
            anchor_points=self.anchor_points,
            prompt_type=self.prompt_type,
            num_vpt=self.num_vpt,
            vpt_drop=self.vpt_drop,
            deep_vpt=self.deep_vpt
        )
        ckpt_path = "assets/CLIP_EBC_nwpu_rmse.pth"
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(self.device)
        self.model.eval()
    def visualize_density_map(self, alpha: float = 0.5, save: bool = False,
                              save_path: Optional[str] = None):
        """
        Visualizes the currently stored prediction result.

        Args:
            alpha (float): Opacity of the density map (0-1). Defaults to 0.5.
            save (bool): Whether to save the visualization as an image. Defaults to False.
            save_path (str, optional): Path to save to. If None, saves to the current
                directory under an auto-generated name. Defaults to None.

        Returns:
            Tuple[matplotlib.figure.Figure, np.ndarray]:
                - matplotlib Figure with the density map overlaid
                - visualized image as an RGB array of shape (H, W, 3)

        Raises:
            ValueError: If density_map or processed_image is None
                (i.e. the predict method has not been run).
        """
        if self.density_map is None or self.processed_image is None:
            raise ValueError("Run the predict method first to produce a prediction.")
        fig, ax = plt.subplots(dpi=200, frameon=False)
        ax.imshow(self.processed_image)
        ax.imshow(self.density_map, cmap="jet", alpha=alpha)
        ax.axis("off")
        if save:
            if save_path is None:
                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                save_path = f"crowd_density_{timestamp}.png"
            # Save with the margins trimmed.
            fig.savefig(save_path, bbox_inches="tight", pad_inches=0, dpi=200)
            print(f"Image saved to: {save_path}")
        fig.canvas.draw()
        image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
        image_from_plot = image_from_plot[:, :, :3]  # drop alpha, keep RGB
        return fig, image_from_plot
    def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97,
                       save: bool = False, save_path: Optional[str] = None):
        """
        Visualizes the predicted crowd positions as dots.

        Args:
            dot_size (int): Size of each dot. Defaults to 20.
            sigma (float): Sigma of the Gaussian filter. Defaults to 1.
            percentile (float): Percentile (0-100) used as the threshold. Defaults to 97.
            save (bool): Whether to save the visualization as an image. Defaults to False.
            save_path (str, optional): Path to save to. If None, saves to the current
                directory under an auto-generated name. Defaults to None.

        Returns:
            Tuple[matplotlib.backends.backend_agg.FigureCanvasBase, np.ndarray]:
                - the figure's canvas object
                - visualized image as an RGB array of shape (H, W, 3)

        Raises:
            ValueError: If density_map or processed_image is None
                (i.e. the predict method has not been run).
        """
        if self.density_map is None or self.processed_image is None:
            raise ValueError("Run the predict method first to produce a prediction.")
        adjusted_pred_count = int(round(self.count))
        fig, ax = plt.subplots(dpi=200, frameon=False)
        ax.imshow(self.processed_image)
        # Smooth the density map, keep only the highest-density pixels, and
        # cluster them so the number of dots matches the predicted count.
        filtered_density = gaussian_filter(self.density_map, sigma=sigma)
        threshold = np.percentile(filtered_density, percentile)
        candidate_pixels = np.column_stack(np.where(filtered_density >= threshold))
        if len(candidate_pixels) > adjusted_pred_count:
            kmeans = KMeans(n_clusters=adjusted_pred_count, random_state=42, n_init=10)
            kmeans.fit(candidate_pixels)
            head_positions = kmeans.cluster_centers_.astype(int)
        else:
            head_positions = candidate_pixels
        y_coords, x_coords = head_positions[:, 0], head_positions[:, 1]
        ax.scatter(x_coords, y_coords,
                   c="red",
                   s=dot_size,
                   alpha=1.0,
                   edgecolors="white",
                   linewidth=1)
        ax.axis("off")
        if save:
            if save_path is None:
                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                save_path = f"crowd_dots_{timestamp}.png"
            fig.savefig(save_path, bbox_inches="tight", pad_inches=0, dpi=200)
            print(f"Image saved to: {save_path}")
        # Convert the figure to a numpy array.
        fig.canvas.draw()
        image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
        image_from_plot = image_from_plot[:, :, :3]  # drop alpha, keep RGB
        return fig.canvas, image_from_plot
    def _process_image(self, image: Union[str, np.ndarray]) -> torch.Tensor:
        """
        Preprocesses an image. Accepts an image path, a numpy array, or a
        Streamlit UploadedFile.

        Args:
            image: Input image. Must be one of:
                - str: path to an image file
                - np.ndarray: RGB image of shape (H, W, 3)
                - UploadedFile: a file uploaded through Streamlit

        Returns:
            torch.Tensor: preprocessed image tensor of shape (1, 3, H, W)

        Raises:
            ValueError: If an unsupported image format is passed in
            Exception: If the image file cannot be opened
        """
        to_tensor = ToTensor()
        normalize = Normalize(mean=self.mean, std=self.std)
        # Keep the original input around.
        self.original_image = image
        # Handle each supported input type.
        if isinstance(image, str):
            # A file path.
            with open(image, "rb") as f:
                pil_image = Image.open(f).convert("RGB")
        elif isinstance(image, np.ndarray):
            # A numpy array.
            if image.dtype == np.uint8:
                pil_image = Image.fromarray(image)
            else:
                # For float dtypes, assume values in [0, 1] and convert.
                pil_image = Image.fromarray((image * 255).astype(np.uint8))
        else:
            # A Streamlit UploadedFile or other file-like object.
            try:
                pil_image = Image.open(image).convert("RGB")
            except Exception as e:
                raise ValueError(f"Unsupported image format: {type(image)}") from e
        # Convert to a tensor and normalize.
        tensor_image = to_tensor(pil_image)
        normalized_image = normalize(tensor_image)
        batched_image = normalized_image.unsqueeze(0)  # (1, 3, H, W)
        batched_image = batched_image.to(self.device)
        return batched_image
    def _post_process_image(self, image):
        """Undoes the input normalization and converts the tensor back to a PIL image."""
        image = normalize(image, mean=(0., 0., 0.),
                          std=(1. / self.std[0], 1. / self.std[1], 1. / self.std[2]))
        image = normalize(image, mean=(-self.mean[0], -self.mean[1], -self.mean[2]),
                          std=(1., 1., 1.))
        processed_image = to_pil_image(image.squeeze(0))
        return processed_image
    def predict(self, image: Union[str, np.ndarray]) -> float:
        """
        Runs crowd-counting inference on an image.

        Args:
            image: Input image; anything _process_image accepts
                (file path, numpy array, or file-like object).

        Returns:
            float: the predicted crowd count

        Note:
            Also stores the de-normalized input image, the density map resized
            to the input resolution, and the count on the instance for use by
            the visualization methods.
        """
        processed_image = self._process_image(image)
        image_height, image_width = processed_image.shape[-2:]
        processed_image = processed_image.to(self.device)
        with torch.no_grad():
            pred_density = sliding_window_predict(self.model, processed_image,
                                                  self.window_size, self.stride)
        pred_count = pred_density.sum().item()
        resized_pred_density = resize_density_map(pred_density,
                                                  (image_height, image_width)).cpu()
        self.processed_image = self._post_process_image(processed_image)
        self.density_map = resized_pred_density.squeeze().numpy()
        self.count = pred_count
        return pred_count
    def crowd_count(self):
        """
        Returns the crowd count from the most recent prediction.

        Returns:
            float: the predicted crowd count
            None: if no prediction has been made yet
        """
        return self.count

    def get_density_map(self):
        """
        Returns the density map from the most recent prediction.

        Returns:
            numpy.ndarray: the density map
            None: if no prediction has been made yet
        """
        return self.density_map
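
# A minimal usage sketch, assuming the "configs" directory and the checkpoint
# under "assets/" are present; "example.jpg" is a placeholder image name, not a
# file shipped with the project:
if __name__ == "__main__":
    counter = ClipEBC()
    count = counter.predict("example.jpg")
    print(f"Estimated crowd count: {count:.1f}")
    fig, overlay = counter.visualize_density_map(alpha=0.5, save=True)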