diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3230874b13c64f7e7b0d3cf5ff6902d93cec8a32 --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +from custom.clip_ebc import ClipEBC + +__all__ = ["ClipEBC"] \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e7b12310e503c356e6d11f5951611c0f8e2fe --- /dev/null +++ b/app.py @@ -0,0 +1,60 @@ +import gradio as gr +from custom.clip_ebc_onnx import ClipEBCOnnx +import numpy as np +import matplotlib.pyplot as plt + +# ONNX 모델 초기화 +model = ClipEBCOnnx() + +def predict_crowd(image): + """ + 이미지를 받아서 군중 수를 예측하고 시각화 결과를 반환합니다. + + Args: + image: Gradio에서 받은 이미지 (numpy array) + + Returns: + tuple: (예측된 군중 수, 밀도 맵 시각화, 점 시각화) + """ + count = model.predict(image) + + # 밀도 맵 시각화 + fig_density, density_map = model.visualize_density_map() + plt.close(fig_density) # 메모리 누수 방지 + # 점 시각화 + canvas, dot_map = model.visualize_dots() + plt.close(canvas.figure) + + return ( + f"예측된 군중 수: {count:.1f}명", + density_map, + dot_map + ) + +with gr.Blocks(title="CLIP-EBC Crowd Counter") as app: + gr.Markdown("# CLIP-EBC Crowd Counter") + gr.Markdown("이미지를 업로드하여 군중 수를 예측하고 시각화합니다.") + + with gr.Row(): + input_image = gr.Image(type="numpy", label="입력 이미지") + + with gr.Row(): + predict_btn = gr.Button("예측", variant="primary") + + with gr.Row(): + count_text = gr.Textbox(label="예측 결과") + + with gr.Row(): + with gr.Column(): + density_output = gr.Image(label="밀도 맵") + with gr.Column(): + dots_output = gr.Image(label="점 시각화") + + predict_btn.click( + fn=predict_crowd, + inputs=input_image, + outputs=[count_text, density_output, dots_output] + ) + +if __name__ == "__main__": + app.launch(share=False) \ No newline at end of file diff --git a/assets/__init__.py b/assets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32269b3f86c1e61d08d959bcfba45156c6e1101d --- /dev/null +++ b/assets/__init__.py @@ -0,0 +1,35 @@ +from huggingface_hub import hf_hub_download +import os + +def download_required_files(): + """Initialize required files from Hugging Face Hub""" + try: + cache_dir = "assets/" + if not os.path.exists(os.path.join(cache_dir, "CLIP_EBC_nwpu_rmse_onnx.onnx")): + hf_hub_download( + repo_id="PIA-SPACE-LAB/CLIP_EBC_nwpu_rmse_onnx", + filename="CLIP_EBC_nwpu_rmse_onnx.onnx", + # cache_dir=cache_dir, + local_dir=cache_dir + ) + print("Required files downloaded successfully") + except Exception as e: + print(f"Error downloading required files: {e}") + +def download_required_files2(): + """Initialize required files from Hugging Face Hub""" + try: + cache_dir = "assets/" + if not os.path.exists(os.path.join(cache_dir, "CLIP_EBC_nwpu_rmse.pth")): + hf_hub_download( + repo_id="PIA-SPACE-LAB/CLIP_EBC_nwpu_rmse", + filename="CLIP_EBC_nwpu_rmse.pth", + # cache_dir=cache_dir, + local_dir=cache_dir + ) + print("Required files downloaded successfully") + except Exception as e: + print(f"Error downloading required files: {e}") + +download_required_files() +download_required_files2() \ No newline at end of file diff --git a/configs/reduction_16.json b/configs/reduction_16.json new file mode 100644 index 0000000000000000000000000000000000000000..58e094d13d2b55cbf4cc1a94c163c8a284b863b0 --- /dev/null +++ b/configs/reduction_16.json @@ -0,0 +1,33 @@ +{ + "8":{ + "qnrf": { + "bins": { + "fine":[ + [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], + [5, 5], [6, 6], [7, 7], [8, "inf"] + ], + "dynamic": [ + [0, 0], [1, 1], [2, 2], [3, 3], + [4, 5], 
[6, 7], [8, "inf"] + ], + "coarse": [ + [0, 0], [1, 2], [3, 4], [5, 6], [7, "inf"] + ] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "average": [0, 1, 2, 3, 4, 5, 6, 7, 9.23349] + }, + "dynamic": { + "middle": [0, 1, 2, 3, 4.5, 6.5, 8], + "average": [0, 1, 2, 3, 4.29278, 6.31441, 9.23349] + }, + "coarse": { + "middle": [0, 1.5, 3.5, 5.5, 7], + "average": [0, 1.14978, 3.27641, 5.30609, 8.11466] + } + } + } + } +} \ No newline at end of file diff --git a/configs/reduction_32.json b/configs/reduction_32.json new file mode 100644 index 0000000000000000000000000000000000000000..78e65c65a2e55442a08d478664f848b08e962712 --- /dev/null +++ b/configs/reduction_32.json @@ -0,0 +1,56 @@ +{ + "19": { + "qnrf": { + "bins": { + "fine": [ + [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], + [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], + [10, 10], [11, 11], [12, 12], [13, 13], [14, 14], + [15, 15], [16, 16], [17, 17], [18, 18], [19, "inf"] + ], + "dynamic": [ + [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], + [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], + [10, 11], [12, 13], [14, 15], [16, 17], [18, "inf"] + ], + "coarse": [ + [0, 0], [1, 2], [3, 4], [5, 6], [7, 8], + [9, 10], [11, 12], [13, 14], [15, 16], [17, 18], + [19, "inf"] + ] + }, + "anchor_points": { + "fine": { + "middle": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19 + ], + "average": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 23.01897 + ] + }, + "dynamic": { + "middle": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10.5, + 12.5, 14.5, 16.5, 18 + ], + "average": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10.42903, + 12.43320, 14.43341, 16.43521, 21.93548 + ] + }, + "coarse": { + "middle": [ + 0, 1.5, 3.5, 5.5, 7.5, 9.5, + 11.5, 13.5, 15.5, 17.5, 19 + ], + "average": [ + 0, 1.23498, 3.36108, 5.40298, 7.41406, 9.42356, + 11.43094, 13.43244, 15.43697, 17.43759, 23.01897 + ] + } + } + } + } +} \ No newline at end of file diff --git a/configs/reduction_8.json b/configs/reduction_8.json new file mode 100644 index 0000000000000000000000000000000000000000..8110c3f5660c34f2f6640bd6324e809cebaa3a49 --- /dev/null +++ b/configs/reduction_8.json @@ -0,0 +1,129 @@ +{ + "2": { + "sha": { + "bins": { + "fine": [[0, 0], [1, 1], [2, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2], + "average": [0, 1, 2.24479] + } + } + }, + "shb": { + "bins": { + "fine": [[0, 0], [1, 1], [2, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2], + "average": [0, 1, 2.15171] + } + } + }, + "nwpu": { + "bins": { + "fine": [[0, 0], [1, 1], [2, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2], + "average": [0, 1, 2.10737] + } + } + }, + "qnrf": { + "bins": { + "fine": [[0, 0], [1, 1], [2, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2], + "average": [0, 1, 2.09296] + } + } + }, + "jhu": { + "bins": { + "fine": [[0, 0], [1, 1], [2, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2], + "average": [0, 1, 2.18589] + } + } + } + }, + "4": { + "sha": { + "bins": { + "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2, 3, 4], + "average": [0, 1, 2, 3, 4.29992] + } + } + }, + "shb": { + "bins": { + "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2, 3, 4], + "average": [0, 1, 2, 3, 4.41009] + } + } + }, + "nwpu": { + "bins": { + "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2, 
3, 4], + "average": [0, 1, 2, 3, 4.21931] + } + } + }, + "qnrf": { + "bins": { + "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2, 3, 4], + "average": [0, 1, 2, 3, 4.21937] + } + } + }, + "jhu": { + "bins": { + "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2, 3, 4], + "average": [0, 1, 2, 3, 4.24058] + } + } + } + }, + "11": { + "qnrf": { + "bins": { + "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, "inf"]] + }, + "anchor_points": { + "fine": { + "middle": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + "average": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + } + } + } + } +} \ No newline at end of file diff --git a/custom/airport_color.py b/custom/airport_color.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4248b744ec3eedc2c175b34bc728d0e67d4623 --- /dev/null +++ b/custom/airport_color.py @@ -0,0 +1,54 @@ +from PIL import Image, ImageDraw +from custom.json2seg import get_segmentation_by_id +import random +INCHEON = "/home/jungseoik/data/PR/CLIP-EBC/assets/incheon.jpg" +COLOR_PAIR = {1: '빨간색', 2: '주황색', 3: '노란색', 4: '초록색', 5: '빨간색', 6: '초록색'} + +def generate_random_color_pair(): + colors = ['빨간색', '주황색', '노란색', '초록색'] + return {i: random.choice(colors) for i in range(1, 7)} + +def create_mask(segmentation, img_size, color): + mask = Image.new('RGBA', img_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(mask) + + polygon = segmentation[0] + points = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)] + + color_map = { + '빨간색': (255, 0, 0, 128), + '주황색': (255, 165, 0, 128), + '노란색': (255, 255, 0, 128), + '초록색': (0, 255, 0, 128), + '파란색': (0, 0, 255, 128), + '보라색': (128, 0, 128, 128) + } + + draw.polygon(points, fill=color_map[color]) + return mask + +def create_all_masks(img_size, region_color_pairs): + """ + Parameters: + - img_size: 이미지 크기 + - region_color_pairs: Dictionary 형태로 {region_id: color} 매핑 + 예: {1: '빨간색', 2: '초록색', 3: '노란색', ...} + """ + # 최종 마스크 생성 + final_mask = Image.new('RGBA', img_size, (0, 0, 0, 0)) + + # 입력받은 region_color_pairs에 따라 마스크 생성 및 합성 + for region_id, color in region_color_pairs.items(): + segmentation = get_segmentation_by_id(target_id=region_id) + region_mask = create_mask(segmentation, img_size, color) + final_mask = Image.alpha_composite(final_mask, region_mask) + + return final_mask + +def airport_map_color(color_pairs = COLOR_PAIR): + # region_color_pairs = COLOR_PAIR + region_color_pairs = generate_random_color_pair() + image = Image.open(INCHEON) + all_masks = create_all_masks(image.size, region_color_pairs) + result = Image.alpha_composite(image.convert('RGBA'), all_masks) + return result diff --git a/custom/clip_ebc.py b/custom/clip_ebc.py new file mode 100644 index 0000000000000000000000000000000000000000..d163512811cccfac2665f7f5ec2b297ec459886d --- /dev/null +++ b/custom/clip_ebc.py @@ -0,0 +1,346 @@ +import os +import sys +import torch +from torchvision.transforms.functional import normalize, to_pil_image +from torchvision.transforms import ToTensor, Normalize +import matplotlib.pyplot as plt +import json +from models import get_model +from utils import resize_density_map, sliding_window_predict +from PIL import Image +import numpy as np +from scipy.ndimage import gaussian_filter +from sklearn.cluster import KMeans +import datetime +from typing import Optional +from typing import Union + +project_root = 
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) +os.environ['CUDA_LAUNCH_BLOCKING'] = "1" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class ClipEBC: + """ + CLIP-EBC (Efficient Boundary Counting) 이미지 처리 클래스입니다. + + CLIP 모델을 사용하여 이미지를 처리하며, 슬라이딩 윈도우 예측 기능을 포함한 + 다양한 설정 옵션을 제공합니다. + + Attributes: + truncation (int): 잘라내기 매개변수. 기본값 4. + reduction (int): 축소 비율. 기본값 8. + granularity (str): 세분화 수준. 기본값 "fine". + anchor_points (str): 앵커 포인트 방법. 기본값 "average". + model_name (str): CLIP 모델 이름. 기본값 "clip_vit_b_16". + input_size (int): 입력 이미지 크기. 기본값 224. + window_size (int): 슬라이딩 윈도우 크기. 기본값 224. + stride (int): 슬라이딩 윈도우 이동 간격. 기본값 224. + prompt_type (str): 프롬프트 유형. 기본값 "word". + dataset_name (str): 데이터셋 이름. 기본값 "qnrf". + num_vpt (int): 비주얼 프롬프트 토큰 수. 기본값 32. + vpt_drop (float): 비주얼 프롬프트 토큰 드롭아웃 비율. 기본값 0.0. + deep_vpt (bool): 깊은 비주얼 프롬프트 토큰 사용 여부. 기본값 True. + mean (tuple): 정규화를 위한 평균값. 기본값 (0.485, 0.456, 0.406). + std (tuple): 정규화를 위한 표준편차값. 기본값 (0.229, 0.224, 0.225). + """ + + def __init__(self, + truncation=4, + reduction=8, + granularity="fine", + anchor_points="average", + model_name="clip_vit_b_16", + input_size=224, + window_size=224, + stride=224, + prompt_type="word", + dataset_name="qnrf", + num_vpt=32, + vpt_drop=0., + deep_vpt=True, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + config_dir="configs"): + """CLIPEBC 클래스를 설정 매개변수와 함께 초기화합니다.""" + self.truncation = truncation + self.reduction = reduction + self.granularity = granularity + self.anchor_points_type = anchor_points # 원래 입력값 저장 + self.model_name = model_name + self.input_size = input_size + self.window_size = window_size + self.stride = stride + self.prompt_type = prompt_type + self.dataset_name = dataset_name + self.num_vpt = num_vpt + self.vpt_drop = vpt_drop + self.deep_vpt = deep_vpt + self.mean = mean + self.std = std + self.config_dir = config_dir + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.bins = None + self.anchor_points = None + self.model = None + + # 초기 설정 로드 및 모델 초기화 + self._load_config() + self._initialize_model() + + def _load_config(self): + """설정 파일을 로드하고 bins와 anchor_points를 설정합니다.""" + config_path = os.path.join(self.config_dir, f"reduction_{self.reduction}.json") + with open(config_path, "r") as f: + config = json.load(f)[str(self.truncation)][self.dataset_name] + + self.bins = config["bins"][self.granularity] + self.bins = [(float(b[0]), float(b[1])) for b in self.bins] + + if self.anchor_points_type == "average": + self.anchor_points = config["anchor_points"][self.granularity]["average"] + else: + self.anchor_points = config["anchor_points"][self.granularity]["middle"] + self.anchor_points = [float(p) for p in self.anchor_points] + + def _initialize_model(self): + """CLIP 모델을 초기화합니다.""" + self.model = get_model( + backbone=self.model_name, + input_size=self.input_size, + reduction=self.reduction, + bins=self.bins, + anchor_points=self.anchor_points, + prompt_type=self.prompt_type, + num_vpt=self.num_vpt, + vpt_drop=self.vpt_drop, + deep_vpt=self.deep_vpt + ) + + ckpt_path = "assets/CLIP_EBC_nwpu_rmse.pth" + ckpt = torch.load(ckpt_path, map_location=device) + self.model.load_state_dict(ckpt) + self.model = self.model.to(device) + self.model.eval() + + def visualize_density_map(self, alpha: float = 0.5, save: bool = False, + save_path: Optional[str] = None): + """ + 현재 저장된 예측 결과를 시각화합니다. + + Args: + alpha (float): density map의 투명도 (0~1). 
기본값 0.5 + save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False + save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장. + 기본값 None + + Returns: + Tuple[matplotlib.figure.Figure, np.ndarray]: + - density map이 오버레이된 matplotlib Figure 객체 + - RGB 형식의 시각화된 이미지 배열 (H, W, 3) + Raises: + ValueError: density_map 또는 processed_image가 None인 경우 (predict 메서드가 실행되지 않은 경우) + """ + if self.density_map is None or self.processed_image is None: + raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.") + + fig, ax = plt.subplots(dpi=200, frameon=False) + ax.imshow(self.processed_image) + ax.imshow(self.density_map, cmap="jet", alpha=alpha) + ax.axis("off") + + if save: + if save_path is None: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"crowd_density_{timestamp}.png" + + # 여백 제거하고 저장 + plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200) + print(f"Image saved to: {save_path}") + + fig.canvas.draw() + image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + image_from_plot = image_from_plot[:,:,:3] # RGB로 변환 + + return fig , image_from_plot + + def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97, + save: bool = False, save_path: Optional[str] = None): + """ + 예측된 군중 위치를 점으로 표시하여 시각화합니다. + + Args: + dot_size (int): 점의 크기. 기본값 20 + sigma (float): Gaussian 필터의 sigma 값. 기본값 1 + percentile (float): 임계값으로 사용할 백분위수 (0-100). 기본값 97 + save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False + save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장. + 기본값 None + + Returns: + Tuple[matplotlib.backends.backend_agg.FigureCanvasBase, np.ndarray]: + - matplotlib figure의 canvas 객체 + - RGB 형식의 시각화된 이미지 배열 (H, W, 3) + Raises: + ValueError: density_map 또는 processed_image가 None인 경우 (predict 메서드가 실행되지 않은 경우) + """ + if self.density_map is None or self.processed_image is None: + raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.") + + adjusted_pred_count = int(round(self.count)) + + fig, ax = plt.subplots(dpi=200, frameon=False) + ax.imshow(self.processed_image) + + filtered_density = gaussian_filter(self.density_map, sigma=sigma) + + threshold = np.percentile(filtered_density, percentile) + candidate_pixels = np.column_stack(np.where(filtered_density >= threshold)) + + if len(candidate_pixels) > adjusted_pred_count: + kmeans = KMeans(n_clusters=adjusted_pred_count, random_state=42, n_init=10) + kmeans.fit(candidate_pixels) + head_positions = kmeans.cluster_centers_.astype(int) + else: + head_positions = candidate_pixels + + y_coords, x_coords = head_positions[:, 0], head_positions[:, 1] + ax.scatter(x_coords, y_coords, + c='red', + s=dot_size, + alpha=1.0, + edgecolors='white', + linewidth=1) + + ax.axis("off") + + if save: + if save_path is None: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"crowd_dots_{timestamp}.png" + + plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200) + print(f"Image saved to: {save_path}") + + # Figure를 numpy 배열로 변환 + fig.canvas.draw() + image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + image_from_plot = image_from_plot[:,:,:3] # RGB로 변환 + + # plt.close(fig) + # return image_from_plot + return fig.canvas, image_from_plot + + def _process_image(self, image: Union[str, np.ndarray]) -> torch.Tensor: + """ + 이미지를 전처리합니다. 
이미지 경로, 넘파이 배열, Streamlit UploadedFile 모두 처리 가능합니다. + + Args: + image: 입력 이미지. 다음 형식 중 하나여야 합니다: + - str: 이미지 파일 경로 + - np.ndarray: (H, W, 3) 형태의 RGB 이미지 + - UploadedFile: Streamlit의 업로드된 파일 + + Returns: + torch.Tensor: 전처리된 이미지 텐서, shape (1, 3, H, W) + + Raises: + ValueError: 지원하지 않는 이미지 형식이 입력된 경우 + Exception: 이미지 파일을 열 수 없는 경우 + """ + to_tensor = ToTensor() + normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + # 원본 이미지 저장 + self.original_image = image + + # 입력 타입에 따른 처리 + if isinstance(image, str): + # 파일 경로인 경우 + with open(image, "rb") as f: + pil_image = Image.open(f).convert("RGB") + elif isinstance(image, np.ndarray): + # 넘파이 배열인 경우 + if image.dtype == np.uint8: + pil_image = Image.fromarray(image) + else: + # float 타입인 경우 [0, 1] 범위로 가정하고 변환 + pil_image = Image.fromarray((image * 255).astype(np.uint8)) + else: + # Streamlit UploadedFile 또는 기타 파일 객체인 경우 + try: + pil_image = Image.open(image).convert("RGB") + except Exception as e: + raise ValueError(f"지원하지 않는 이미지 형식입니다: {type(image)}") from e + + # 텐서 변환 및 정규화 + tensor_image = to_tensor(pil_image) + normalized_image = normalize(tensor_image) + batched_image = normalized_image.unsqueeze(0) # (1, 3, H, W) + batched_image = batched_image.to(self.device) + + return batched_image + def _post_process_image(self, image): + """이미지 후처리를 수행합니다.""" + image = normalize(image, mean=(0., 0., 0.), + std=(1. / self.std[0], 1. / self.std[1], 1. / self.std[2])) + image = normalize(image, mean=(-self.mean[0], -self.mean[1], -self.mean[2]), + std=(1., 1., 1.)) + processed_image = to_pil_image(image.squeeze(0)) + return processed_image + + @torch.no_grad() + def predict(self, image: torch.Tensor) -> Image.Image: + """ + 모델 출력 이미지의 후처리를 수행합니다. + + Args: + image (torch.Tensor): 후처리할 이미지 텐서, shape (1, 3, H, W) + + Returns: + PIL.Image.Image: 후처리된 PIL 이미지 + + Note: + 이미지 텐서에 대해 정규화를 역변환하고 PIL 이미지 형식으로 변환합니다. + self.mean과 self.std 값을 사용하여 원본 이미지의 스케일로 복원합니다. + """ + processed_image = self._process_image(image) + image_height, image_width = processed_image.shape[-2:] + processed_image = processed_image.to(self.device) + + pred_density = sliding_window_predict(self.model, processed_image, + self.window_size, self.stride) + pred_count = pred_density.sum().item() + resized_pred_density = resize_density_map(pred_density, + (image_height, image_width)).cpu() + + self.processed_image = self._post_process_image(processed_image) + self.density_map = resized_pred_density.squeeze().numpy() + self.count = pred_count + + return pred_count + + def crowd_count(self): + """ + 가장 최근 예측의 군중 수를 반환합니다. + + Returns: + float: 예측된 군중 수 + None: 아직 예측이 수행되지 않은 경우 + """ + return self.count + + def get_density_map(self): + """ + 가장 최근 예측의 밀도 맵을 반환합니다. 
+ + Returns: + numpy.ndarray: 밀도 맵 + None: 아직 예측이 수행되지 않은 경우 + """ + return self.density_map + \ No newline at end of file diff --git a/custom/clip_ebc_onnx.py b/custom/clip_ebc_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9260394919c10e00aa6907a14b7001ca2504b0 --- /dev/null +++ b/custom/clip_ebc_onnx.py @@ -0,0 +1,465 @@ +import os +import sys +import torch +import numpy as np +import onnxruntime as ort +from typing import Union, Tuple, Optional +from PIL import Image +import matplotlib.pyplot as plt +from torchvision.transforms import ToTensor, Normalize +from torchvision.transforms.functional import normalize, to_pil_image +import json +import datetime +from scipy.ndimage import gaussian_filter +from sklearn.cluster import KMeans +import assets + +# 프로젝트 루트 디렉토리 설정 +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) + +class ClipEBCOnnx: + """ + CLIP-EBC (Efficient Boundary Counting) ONNX 버전 이미지 처리 클래스입니다. + + ONNX로 변환된 CLIP 모델을 사용하여 이미지를 처리하며, 슬라이딩 윈도우 예측 기능을 포함한 + 다양한 설정 옵션을 제공합니다. + """ + + def __init__(self, + onnx_model_path="assets/CLIP_EBC_nwpu_rmse_onnx.onnx", + truncation=4, + reduction=8, + granularity="fine", + anchor_points="average", + input_size=224, + window_size=224, + stride=224, + dataset_name="qnrf", + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + config_dir="configs"): + """CLIPEBC ONNX 클래스를 설정 매개변수와 함께 초기화합니다.""" + self.onnx_model_path = onnx_model_path + self.truncation = truncation + self.reduction = reduction + self.granularity = granularity + self.anchor_points_type = anchor_points + self.input_size = input_size + self.window_size = window_size + self.stride = stride + self.dataset_name = dataset_name + self.mean = mean + self.std = std + self.config_dir = config_dir + + # 결과 저장용 변수 초기화 + self.density_map = None + self.processed_image = None + self.count = None + self.original_image = None + + # ONNX 추론 세션 설정 + self.session_options = ort.SessionOptions() + self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + + # 가능한 경우 GPU 사용 + self.providers = [] + if 'CUDAExecutionProvider' in ort.get_available_providers(): + self.providers.append('CUDAExecutionProvider') + self.providers.append('CPUExecutionProvider') + + # ONNX 런타임 세션 초기화 + print(f"ONNX 모델 로드 중: {self.onnx_model_path}") + self.session = ort.InferenceSession( + self.onnx_model_path, + sess_options=self.session_options, + providers=self.providers + ) + + # 모델의 입력 및 출력 이름 가져오기 + self.input_name = self.session.get_inputs()[0].name + self.output_name = self.session.get_outputs()[0].name + + print(f"입력 이름: {self.input_name}, 형태: {self.session.get_inputs()[0].shape}") + print(f"출력 이름: {self.output_name}, 형태: {self.session.get_outputs()[0].shape}") + print(f"실행 제공자: {self.providers}") + + def _process_image(self, image: Union[str, np.ndarray]) -> np.ndarray: + """ + 이미지를 전처리합니다. 이미지 경로, 넘파이 배열, Streamlit UploadedFile 모두 처리 가능합니다. + + Args: + image: 입력 이미지. 
다음 형식 중 하나여야 합니다: + - str: 이미지 파일 경로 + - np.ndarray: (H, W, 3) 형태의 RGB 이미지 + - UploadedFile: Streamlit의 업로드된 파일 + + Returns: + np.ndarray: 전처리된 이미지 배열, shape (1, 3, H, W) + """ + to_tensor = ToTensor() + normalize = Normalize(mean=self.mean, std=self.std) + + # 원본 이미지 저장 + self.original_image = image + + # 입력 타입에 따른 처리 + if isinstance(image, str): + # 파일 경로인 경우 + with open(image, "rb") as f: + pil_image = Image.open(f).convert("RGB") + elif isinstance(image, np.ndarray): + # 넘파이 배열인 경우 + if image.dtype == np.uint8: + pil_image = Image.fromarray(image) + else: + # float 타입인 경우 [0, 1] 범위로 가정하고 변환 + pil_image = Image.fromarray((image * 255).astype(np.uint8)) + else: + # Streamlit UploadedFile 또는 기타 파일 객체인 경우 + try: + pil_image = Image.open(image).convert("RGB") + except Exception as e: + raise ValueError(f"지원하지 않는 이미지 형식입니다: {type(image)}") from e + + # 텐서 변환 및 정규화 + tensor_image = to_tensor(pil_image) + normalized_image = normalize(tensor_image) + batched_image = normalized_image.unsqueeze(0) # (1, 3, H, W) + + # numpy로 변환 + numpy_image = batched_image.numpy() + + return numpy_image + + def _post_process_image(self, image_tensor): + """이미지 텐서를 PIL 이미지로 변환합니다.""" + # NumPy 배열을 PyTorch 텐서로 변환 + if isinstance(image_tensor, np.ndarray): + image_tensor = torch.from_numpy(image_tensor) + + # 정규화 역변환 + image = normalize( + image_tensor, + mean=[0., 0., 0.], + std=[1./self.std[0], 1./self.std[1], 1./self.std[2]] + ) + + image = normalize( + image, + mean=[-self.mean[0], -self.mean[1], -self.mean[2]], + std=[1., 1., 1.] + ) + + # 배치 차원 제거 및 PIL 이미지로 변환 + processed_image = to_pil_image(image.squeeze(0)) + return processed_image + + def sliding_window_predict(self, image: np.ndarray, window_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]]) -> np.ndarray: + """ + 슬라이딩 윈도우 방식으로 이미지 예측을 수행합니다. 겹치는 영역은 평균값을 사용합니다. + + Args: + image (np.ndarray): 형태가 (1, 3, H, W)인 이미지 배열 + window_size (int or tuple): 윈도우 크기 + stride (int or tuple): 윈도우 이동 간격 + + Returns: + np.ndarray: 예측된 밀도 맵 + """ + # 입력 검증 + assert len(image.shape) == 4, f"이미지는 4차원 배열이어야 합니다. (1, C, H, W), 현재: {image.shape}" + + # 윈도우 크기와 스트라이드 설정 + window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size + stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride + window_size = tuple(window_size) + stride = tuple(stride) + + # 검증 + assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, \ + f"윈도우 크기는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {window_size}" + assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, \ + f"스트라이드는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {stride}" + assert stride[0] <= window_size[0] and stride[1] <= window_size[1], \ + f"스트라이드는 윈도우 크기보다 작아야 합니다. 
현재: {stride}와 {window_size}" + + image_height, image_width = image.shape[-2:] + window_height, window_width = window_size + stride_height, stride_width = stride + + # 슬라이딩 윈도우 수 계산 + num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1) + num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1) + + # 윈도우 추출 + windows = [] + window_positions = [] + for i in range(num_rows): + for j in range(num_cols): + x_start, y_start = i * stride_height, j * stride_width + x_end, y_end = x_start + window_height, y_start + window_width + + # 이미지 경계 처리 + if x_end > image_height: + x_start, x_end = image_height - window_height, image_height + if y_end > image_width: + y_start, y_end = image_width - window_width, image_width + + window = image[:, :, x_start:x_end, y_start:y_end] + windows.append(window) + window_positions.append((x_start, y_start, x_end, y_end)) + + # 배치 단위로 추론 + all_preds = [] + max_batch_size = 8 + + for start_idx in range(0, len(windows), max_batch_size): + end_idx = min(start_idx + max_batch_size, len(windows)) + batch_windows = np.vstack(windows[start_idx:end_idx]) # (batch_size, 3, h, w) + + # ONNX 추론 + ort_inputs = {self.input_name: batch_windows} + batch_preds = self.session.run([self.output_name], ort_inputs)[0] + + # Debug 정보 + # print(f"배치 입력 형태: {batch_windows.shape}, 배치 출력 형태: {batch_preds.shape}") + + all_preds.extend([batch_preds[i:i+1] for i in range(batch_preds.shape[0])]) + + # 예측 결과를 numpy 배열로 변환 + preds = np.concatenate(all_preds, axis=0) + + # 출력 밀도 맵 조립 + pred_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32) + count_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32) + + idx = 0 + for i in range(num_rows): + for j in range(num_cols): + x_start, y_start, x_end, y_end = window_positions[idx] + + # 출력 영역 계산 (reduction 고려) + x_start_out = x_start // self.reduction + y_start_out = y_start // self.reduction + x_end_out = x_end // self.reduction + y_end_out = y_end // self.reduction + + pred_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += preds[idx] + count_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += 1. + idx += 1 + + # 겹치는 영역 평균 계산 + pred_map /= count_map + + return pred_map + + def resize_density_map(self, density_map: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray: + """ + 밀도 맵의 크기를 조정합니다. 총합은 보존됩니다. + + Args: + density_map: 형태가 (C, H, W)인 밀도 맵 + target_size: 목표 크기 (H', W') + + Returns: + np.ndarray: 크기가 조정된 밀도 맵 + """ + from PIL import Image + import torch.nn.functional as F + import torch + + # numpy를 torch로 변환 + if isinstance(density_map, np.ndarray): + density_map = torch.from_numpy(density_map) + + # 배치 차원 추가 + if density_map.dim() == 3: + density_map = density_map.unsqueeze(0) # (1, C, H, W) + + current_size = density_map.shape[2:] + + if current_size[0] == target_size[0] and current_size[1] == target_size[1]: + return density_map.squeeze(0).numpy() + + # 원본 밀도 맵의 총합 계산 + original_sum = density_map.sum() + + # 크기 조정 (쌍선형 보간) + resized_map = F.interpolate( + density_map, + size=target_size, + mode='bilinear', + align_corners=False + ) + + # 총합 보존을 위한 스케일링 + if resized_map.sum() > 0: # 0으로 나누기 방지 + resized_map = resized_map * (original_sum / resized_map.sum()) + + return resized_map.squeeze(0).numpy() + + def predict(self, image: Union[str, np.ndarray]) -> float: + """ + 이미지에서 군중 계수 예측을 수행합니다. 
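+
+        A minimal usage sketch (illustrative; the image path is a placeholder, and the
+        names follow this module):
+
+            >>> model = ClipEBCOnnx()
+            >>> count = model.predict("crowd.jpg")
+            >>> fig, overlay = model.visualize_density_map()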
+ + Args: + image: 입력 이미지 (경로, 넘파이 배열, 또는 업로드된 파일) + + Returns: + float: 예측된 사람 수 + """ + # 이미지 전처리 + processed_image = self._process_image(image) + image_height, image_width = processed_image.shape[-2:] + + # 슬라이딩 윈도우 예측 + pred_density = self.sliding_window_predict( + processed_image, + self.window_size, + self.stride + ) + + # 예측 결과 저장 + pred_count = pred_density.sum() + + # 원본 이미지 크기로 밀도 맵 조정 + resized_pred_density = self.resize_density_map( + pred_density, + (image_height, image_width) + ) + + # 결과 저장 + self.processed_image = self._post_process_image(processed_image) + self.density_map = resized_pred_density.squeeze() + self.count = pred_count + + return pred_count + + def visualize_density_map(self, alpha: float = 0.5, save: bool = False, + save_path: Optional[str] = None): + """ + 현재 저장된 예측 결과를 시각화합니다. + + Args: + alpha (float): density map의 투명도 (0~1). 기본값 0.5 + save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False + save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장. + 기본값 None + + Returns: + Tuple[matplotlib.figure.Figure, np.ndarray]: + - density map이 오버레이된 matplotlib Figure 객체 + - RGB 형식의 시각화된 이미지 배열 (H, W, 3) + """ + if self.density_map is None or self.processed_image is None: + raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.") + + fig, ax = plt.subplots(dpi=200, frameon=False) + ax.imshow(self.processed_image) + ax.imshow(self.density_map, cmap="jet", alpha=alpha) + ax.axis("off") + plt.title(f"Count: {self.count:.1f}") + + if save: + if save_path is None: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"crowd_density_{timestamp}.png" + + # 여백 제거하고 저장 + plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200) + print(f"이미지 저장 완료: {save_path}") + + fig.canvas.draw() + image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + image_from_plot = image_from_plot[:,:,:3] # RGB로 변환 + + return fig, image_from_plot + + def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97, + save: bool = False, save_path: Optional[str] = None): + """ + 예측된 군중 위치를 점으로 표시하여 시각화합니다. + + Args: + dot_size (int): 점의 크기. 기본값 20 + sigma (float): Gaussian 필터의 sigma 값. 기본값 1 + percentile (float): 임계값으로 사용할 백분위수 (0-100). 기본값 97 + save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False + save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장. 
+ 기본값 None + + Returns: + Tuple[matplotlib.backends.backend_agg.FigureCanvasBase, np.ndarray]: + - matplotlib figure의 canvas 객체 + - RGB 형식의 시각화된 이미지 배열 (H, W, 3) + """ + if self.density_map is None or self.processed_image is None: + raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.") + + adjusted_pred_count = int(round(self.count)) + + fig, ax = plt.subplots(dpi=200, frameon=False) + ax.imshow(self.processed_image) + + filtered_density = gaussian_filter(self.density_map, sigma=sigma) + + threshold = np.percentile(filtered_density, percentile) + candidate_pixels = np.column_stack(np.where(filtered_density >= threshold)) + + if len(candidate_pixels) > adjusted_pred_count: + kmeans = KMeans(n_clusters=adjusted_pred_count, random_state=42, n_init=10) + kmeans.fit(candidate_pixels) + head_positions = kmeans.cluster_centers_.astype(int) + else: + head_positions = candidate_pixels + + y_coords, x_coords = head_positions[:, 0], head_positions[:, 1] + ax.scatter(x_coords, y_coords, + c='red', + s=dot_size, + alpha=1.0, + edgecolors='white', + linewidth=1) + + ax.axis("off") + plt.title(f"Count: {self.count:.1f}") + + if save: + if save_path is None: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"crowd_dots_{timestamp}.png" + + plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200) + print(f"이미지 저장 완료: {save_path}") + + # Figure를 numpy 배열로 변환 + fig.canvas.draw() + image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + image_from_plot = image_from_plot[:,:,:3] # RGB로 변환 + + return fig.canvas, image_from_plot + + def crowd_count(self): + """ + 가장 최근 예측의 군중 수를 반환합니다. + + Returns: + float: 예측된 군중 수 + None: 아직 예측이 수행되지 않은 경우 + """ + return self.count + + def get_density_map(self): + """ + 가장 최근 예측의 밀도 맵을 반환합니다. + + Returns: + numpy.ndarray: 밀도 맵 + None: 아직 예측이 수행되지 않은 경우 + """ + return self.density_map \ No newline at end of file diff --git a/custom/clip_ebc_tensorrt.py b/custom/clip_ebc_tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..41f2d51665b1dc309257a44a2d8844a77eb87caa --- /dev/null +++ b/custom/clip_ebc_tensorrt.py @@ -0,0 +1,603 @@ +import os +import sys +import torch +import numpy as np +import tensorrt as trt +from typing import Union, Tuple, Optional +from PIL import Image +import matplotlib.pyplot as plt +from torchvision.transforms import ToTensor, Normalize +from torchvision.transforms.functional import normalize, to_pil_image +import json +import datetime +from scipy.ndimage import gaussian_filter +from sklearn.cluster import KMeans +import assets + +# 프로젝트 루트 디렉토리 설정 +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) + +class ClipEBCTensorRT: + """ + CLIP-EBC (Efficient Boundary Counting) TensorRT 버전 이미지 처리 클래스입니다. + + TensorRT로 변환된 CLIP 모델을 사용하여 이미지를 처리하며, 슬라이딩 윈도우 예측 기능을 포함한 + 다양한 설정 옵션을 제공합니다. 
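+
+    Inference requires the tensorrt and pycuda packages and a serialized engine file
+    (default: assets/CLIP_EBC_nwpu_rmse_tensorrt.trt).
+
+    Minimal usage sketch (illustrative; the image path is a placeholder):
+
+        >>> model = ClipEBCTensorRT()
+        >>> count = model.predict("crowd.jpg")
+        >>> print(f"Estimated count: {count:.1f}")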
+ """ + + def __init__(self, + engine_path="assets/CLIP_EBC_nwpu_rmse_tensorrt.trt", + truncation=4, + reduction=8, + granularity="fine", + anchor_points="average", + input_size=224, + window_size=224, + stride=224, + dataset_name="qnrf", + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + config_dir ="configs"): + """CLIPEBC TensorRT 클래스를 설정 매개변수와 함께 초기화합니다.""" + self.engine_path = engine_path + self.truncation = truncation + self.reduction = reduction + self.granularity = granularity + self.anchor_points_type = anchor_points + self.input_size = input_size + self.window_size = window_size + self.stride = stride + self.dataset_name = dataset_name + self.mean = mean + self.std = std + self.config_dir = config_dir + + # 결과 저장용 변수 초기화 + self.density_map = None + self.processed_image = None + self.count = None + self.original_image = None + + # TensorRT 엔진 로드 + print(f"TensorRT 엔진 로드 중: {self.engine_path}") + self._load_engine() + + # 입력 및 출력 이름 설정 + self.input_name = "input" + self.output_name = "output" + + print(f"TensorRT 엔진 초기화 완료") + + def _load_engine(self): + """TensorRT 엔진을 로드합니다.""" + # TensorRT 로거 생성 + TRT_LOGGER = trt.Logger(trt.Logger.WARNING) + + # 런타임 생성 + self.runtime = trt.Runtime(TRT_LOGGER) + + # 엔진 파일 로드 + with open(self.engine_path, 'rb') as f: + engine_data = f.read() + + # 직렬화된 엔진에서 엔진 생성 + self.engine = self.runtime.deserialize_cuda_engine(engine_data) + + # 실행 컨텍스트 생성 + self.context = self.engine.create_execution_context() + + # TensorRT 10.x에서는 input_binding/output_binding 대신 네트워크 구조를 확인 + # 입력과 출력을 가져오는 방법이 변경됨 + self.num_io_tensors = self.engine.num_io_tensors + + # 입력과 출력 텐서 이름 찾기 + self.input_tensor_names = [] + self.output_tensor_names = [] + + print(f"TensorRT 엔진에서 {self.num_io_tensors}개의 IO 텐서를 찾았습니다") + + for i in range(self.num_io_tensors): + name = self.engine.get_tensor_name(i) + is_input = self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT + + if is_input: + self.input_tensor_names.append(name) + else: + self.output_tensor_names.append(name) + + # 입력과 출력 이름 설정 + if not self.input_tensor_names: + raise ValueError("엔진에서 입력 텐서를 찾을 수 없습니다.") + if not self.output_tensor_names: + raise ValueError("엔진에서 출력 텐서를 찾을 수 없습니다.") + + # 기본 입력 및 출력 이름 설정 + self.input_name = self.input_tensor_names[0] + self.output_name = self.output_tensor_names[0] + + # 입출력 형태 추출 + self.input_shape = self.engine.get_tensor_shape(self.input_name) + self.output_shape = self.engine.get_tensor_shape(self.output_name) + + print(f"입력 이름: {self.input_name}, 형태: {self.input_shape}") + print(f"출력 이름: {self.output_name}, 형태: {self.output_shape}") + + def _process_image(self, image: Union[str, np.ndarray]) -> np.ndarray: + """ + 이미지를 전처리합니다. 이미지 경로, 넘파이 배열, Streamlit UploadedFile 모두 처리 가능합니다. + + Args: + image: 입력 이미지. 
다음 형식 중 하나여야 합니다: + - str: 이미지 파일 경로 + - np.ndarray: (H, W, 3) 형태의 RGB 이미지 + - UploadedFile: Streamlit의 업로드된 파일 + + Returns: + np.ndarray: 전처리된 이미지 배열, shape (1, 3, H, W) + """ + to_tensor = ToTensor() + normalize = Normalize(mean=self.mean, std=self.std) + + # 원본 이미지 저장 + self.original_image = image + + # 입력 타입에 따른 처리 + if isinstance(image, str): + # 파일 경로인 경우 + with open(image, "rb") as f: + pil_image = Image.open(f).convert("RGB") + elif isinstance(image, np.ndarray): + # 넘파이 배열인 경우 + if image.dtype == np.uint8: + pil_image = Image.fromarray(image) + else: + # float 타입인 경우 [0, 1] 범위로 가정하고 변환 + pil_image = Image.fromarray((image * 255).astype(np.uint8)) + else: + # Streamlit UploadedFile 또는 기타 파일 객체인 경우 + try: + pil_image = Image.open(image).convert("RGB") + except Exception as e: + raise ValueError(f"지원하지 않는 이미지 형식입니다: {type(image)}") from e + + # 텐서 변환 및 정규화 + tensor_image = to_tensor(pil_image) + normalized_image = normalize(tensor_image) + batched_image = normalized_image.unsqueeze(0) # (1, 3, H, W) + + # numpy로 변환 + numpy_image = batched_image.numpy() + + return numpy_image + + def _post_process_image(self, image_tensor): + """이미지 텐서를 PIL 이미지로 변환합니다.""" + # NumPy 배열을 PyTorch 텐서로 변환 + if isinstance(image_tensor, np.ndarray): + image_tensor = torch.from_numpy(image_tensor) + + # 정규화 역변환 + image = normalize( + image_tensor, + mean=[0., 0., 0.], + std=[1./self.std[0], 1./self.std[1], 1./self.std[2]] + ) + + image = normalize( + image, + mean=[-self.mean[0], -self.mean[1], -self.mean[2]], + std=[1., 1., 1.] + ) + + # 배치 차원 제거 및 PIL 이미지로 변환 + processed_image = to_pil_image(image.squeeze(0)) + return processed_image + def _infer_batch(self, batch_input): + """ + TensorRT 엔진을 사용하여 배치 추론을 수행합니다. (수정 버전) + """ + import pycuda.driver as cuda + import pycuda.autoinit + import numpy as np + + batch_size = batch_input.shape[0] + + # 입력의 형태와 데이터 타입 확인 + input_shape = (batch_size, 3, self.input_size, self.input_size) + print(f"입력 배치 형태: {batch_input.shape}, 데이터 타입: {batch_input.dtype}") + + # 입력 형태 검증 + if batch_input.shape != input_shape: + print(f"경고: 입력 형태 불일치. 예상: {input_shape}, 실제: {batch_input.shape}") + # 필요시 형태 수정 + batch_input = np.resize(batch_input, input_shape) + + # 데이터 타입 검증 + if batch_input.dtype != np.float32: + print(f"경고: 입력 데이터 타입 불일치. 
float32로 변환합니다.") + batch_input = batch_input.astype(np.float32) + + # 동적 배치 크기 설정 + self.context.set_input_shape(self.input_name, input_shape) + + # 출력 형태 가져오기 + output_shape = self.context.get_tensor_shape(self.output_name) + output_shape = tuple(output_shape) # 튜플로 변환하여 안전성 보장 + print(f"출력 형태: {output_shape}") + + # -1 값을 실제 배치 크기로 대체 + if output_shape[0] == -1: + output_shape = (batch_size,) + output_shape[1:] + + # 출력 버퍼 준비 + output = np.empty(output_shape, dtype=np.float32) + + # 호스트 메모리 준비 (페이지 잠금 메모리 사용) + h_input = cuda.pagelocked_empty(batch_input.shape, dtype=np.float32) + h_output = cuda.pagelocked_empty(output_shape, dtype=np.float32) + + # 입력 데이터 복사 + np.copyto(h_input, batch_input) + + # 디바이스 메모리 할당 + d_input = cuda.mem_alloc(h_input.nbytes) + d_output = cuda.mem_alloc(h_output.nbytes) + + # CUDA 스트림 생성 + stream = cuda.Stream() + + try: + # 메모리 복사 (호스트 -> 디바이스) + cuda.memcpy_htod_async(d_input, h_input, stream) + + # 텐서 주소 설정 + self.context.set_tensor_address(self.input_name, int(d_input)) + self.context.set_tensor_address(self.output_name, int(d_output)) + + # 디버깅 정보 (메모리 주소) + print(f"입력 메모리 주소: {int(d_input)}, 출력 메모리 주소: {int(d_output)}") + + # 실행 + success = self.context.execute_async_v3(stream_handle=stream.handle) + if not success: + print("TensorRT 실행 실패") + return None + + # 메모리 복사 (디바이스 -> 호스트) + cuda.memcpy_dtoh_async(h_output, d_output, stream) + + # 스트림 동기화 + stream.synchronize() + + # 출력 데이터 복사 + np.copyto(output, h_output) + + return output + + except Exception as e: + print(f"TensorRT 추론 중 오류 발생: {str(e)}") + import traceback + traceback.print_exc() + return None + + finally: + # 메모리 해제 + del stream + if 'd_input' in locals(): + d_input.free() + if 'd_output' in locals(): + d_output.free() + + def sliding_window_predict(self, image: np.ndarray, window_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]]) -> np.ndarray: + """ + 슬라이딩 윈도우 방식으로 이미지 예측을 수행합니다. 겹치는 영역은 평균값을 사용합니다. + + Args: + image (np.ndarray): 형태가 (1, 3, H, W)인 이미지 배열 + window_size (int or tuple): 윈도우 크기 + stride (int or tuple): 윈도우 이동 간격 + + Returns: + np.ndarray: 예측된 밀도 맵 + """ + # CUDA 초기화 (처음 사용할 때만) + global cuda + if 'cuda' not in globals(): + import pycuda.driver as cuda + cuda.init() + + # 입력 검증 + assert len(image.shape) == 4, f"이미지는 4차원 배열이어야 합니다. (1, C, H, W), 현재: {image.shape}" + + # 윈도우 크기와 스트라이드 설정 + window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size + stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride + window_size = tuple(window_size) + stride = tuple(stride) + + # 검증 + assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, \ + f"윈도우 크기는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {window_size}" + assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, \ + f"스트라이드는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {stride}" + assert stride[0] <= window_size[0] and stride[1] <= window_size[1], \ + f"스트라이드는 윈도우 크기보다 작아야 합니다. 
현재: {stride}와 {window_size}" + + image_height, image_width = image.shape[-2:] + window_height, window_width = window_size + stride_height, stride_width = stride + + # 슬라이딩 윈도우 수 계산 + num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1) + num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1) + + # 윈도우 추출 + windows = [] + window_positions = [] + for i in range(num_rows): + for j in range(num_cols): + x_start, y_start = i * stride_height, j * stride_width + x_end, y_end = x_start + window_height, y_start + window_width + + # 이미지 경계 처리 + if x_end > image_height: + x_start, x_end = image_height - window_height, image_height + if y_end > image_width: + y_start, y_end = image_width - window_width, image_width + + window = image[:, :, x_start:x_end, y_start:y_end] + windows.append(window) + window_positions.append((x_start, y_start, x_end, y_end)) + + # 배치 단위로 추론 + all_preds = [] + max_batch_size = 8 + + for start_idx in range(0, len(windows), max_batch_size): + end_idx = min(start_idx + max_batch_size, len(windows)) + batch_windows = np.vstack(windows[start_idx:end_idx]) # (batch_size, 3, h, w) + + # TensorRT 추론 + batch_preds = self._infer_batch(batch_windows) + + # Debug 정보 + # print(f"배치 입력 형태: {batch_windows.shape}, 배치 출력 형태: {batch_preds.shape}") + + all_preds.extend([batch_preds[i:i+1] for i in range(batch_preds.shape[0])]) + + # 예측 결과를 numpy 배열로 변환 + preds = np.concatenate(all_preds, axis=0) + + # 출력 밀도 맵 조립 + pred_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32) + count_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32) + + idx = 0 + for i in range(num_rows): + for j in range(num_cols): + x_start, y_start, x_end, y_end = window_positions[idx] + + # 출력 영역 계산 (reduction 고려) + x_start_out = x_start // self.reduction + y_start_out = y_start // self.reduction + x_end_out = x_end // self.reduction + y_end_out = y_end // self.reduction + + pred_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += preds[idx] + count_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += 1. + idx += 1 + + # 겹치는 영역 평균 계산 + pred_map /= count_map + + return pred_map + + def resize_density_map(self, density_map: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray: + """ + 밀도 맵의 크기를 조정합니다. 총합은 보존됩니다. + + Args: + density_map: 형태가 (C, H, W)인 밀도 맵 + target_size: 목표 크기 (H', W') + + Returns: + np.ndarray: 크기가 조정된 밀도 맵 + """ + from PIL import Image + import torch.nn.functional as F + import torch + + # numpy를 torch로 변환 + if isinstance(density_map, np.ndarray): + density_map = torch.from_numpy(density_map) + + # 배치 차원 추가 + if density_map.dim() == 3: + density_map = density_map.unsqueeze(0) # (1, C, H, W) + + current_size = density_map.shape[2:] + + if current_size[0] == target_size[0] and current_size[1] == target_size[1]: + return density_map.squeeze(0).numpy() + + # 원본 밀도 맵의 총합 계산 + original_sum = density_map.sum() + + # 크기 조정 (쌍선형 보간) + resized_map = F.interpolate( + density_map, + size=target_size, + mode='bilinear', + align_corners=False + ) + + # 총합 보존을 위한 스케일링 + if resized_map.sum() > 0: # 0으로 나누기 방지 + resized_map = resized_map * (original_sum / resized_map.sum()) + + return resized_map.squeeze(0).numpy() + + def predict(self, image: Union[str, np.ndarray]) -> float: + """ + 이미지에서 군중 계수 예측을 수행합니다. 
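+
+        The returned count is the sum of the predicted density map; schematically
+        (a sketch of what this method does internally, where x is the preprocessed
+        (1, 3, H, W) array produced by _process_image):
+
+            >>> density = self.sliding_window_predict(x, self.window_size, self.stride)
+            >>> count = float(density.sum())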
+ + Args: + image: 입력 이미지 (경로, 넘파이 배열, 또는 업로드된 파일) + + Returns: + float: 예측된 사람 수 + """ + # 이미지 전처리 + processed_image = self._process_image(image) + image_height, image_width = processed_image.shape[-2:] + + # 슬라이딩 윈도우 예측 + pred_density = self.sliding_window_predict( + processed_image, + self.window_size, + self.stride + ) + + # 예측 결과 저장 + pred_count = pred_density.sum() + + # 원본 이미지 크기로 밀도 맵 조정 + resized_pred_density = self.resize_density_map( + pred_density, + (image_height, image_width) + ) + + # 결과 저장 + self.processed_image = self._post_process_image(processed_image) + self.density_map = resized_pred_density.squeeze() + self.count = pred_count + + return pred_count + + def visualize_density_map(self, alpha: float = 0.5, save: bool = False, + save_path: Optional[str] = None): + """ + 현재 저장된 예측 결과를 시각화합니다. + + Args: + alpha (float): density map의 투명도 (0~1). 기본값 0.5 + save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False + save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장. + 기본값 None + + Returns: + Tuple[matplotlib.figure.Figure, np.ndarray]: + - density map이 오버레이된 matplotlib Figure 객체 + - RGB 형식의 시각화된 이미지 배열 (H, W, 3) + """ + if self.density_map is None or self.processed_image is None: + raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.") + + fig, ax = plt.subplots(dpi=200, frameon=False) + ax.imshow(self.processed_image) + ax.imshow(self.density_map, cmap="jet", alpha=alpha) + ax.axis("off") + plt.title(f"Count: {self.count:.1f}") + + if save: + if save_path is None: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"crowd_density_{timestamp}.png" + + # 여백 제거하고 저장 + plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200) + print(f"이미지 저장 완료: {save_path}") + + fig.canvas.draw() + image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + image_from_plot = image_from_plot[:,:,:3] # RGB로 변환 + + return fig, image_from_plot + + def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97, + save: bool = False, save_path: Optional[str] = None): + """ + 예측된 군중 위치를 점으로 표시하여 시각화합니다. + + Args: + dot_size (int): 점의 크기. 기본값 20 + sigma (float): Gaussian 필터의 sigma 값. 기본값 1 + percentile (float): 임계값으로 사용할 백분위수 (0-100). 기본값 97 + save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False + save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장. 
+ 기본값 None + + Returns: + Tuple[matplotlib.backends.backend_agg.FigureCanvasBase, np.ndarray]: + - matplotlib figure의 canvas 객체 + - RGB 형식의 시각화된 이미지 배열 (H, W, 3) + """ + if self.density_map is None or self.processed_image is None: + raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.") + + adjusted_pred_count = int(round(self.count)) + + fig, ax = plt.subplots(dpi=200, frameon=False) + ax.imshow(self.processed_image) + + filtered_density = gaussian_filter(self.density_map, sigma=sigma) + + threshold = np.percentile(filtered_density, percentile) + candidate_pixels = np.column_stack(np.where(filtered_density >= threshold)) + + if len(candidate_pixels) > adjusted_pred_count: + kmeans = KMeans(n_clusters=adjusted_pred_count, random_state=42, n_init=10) + kmeans.fit(candidate_pixels) + head_positions = kmeans.cluster_centers_.astype(int) + else: + head_positions = candidate_pixels + + y_coords, x_coords = head_positions[:, 0], head_positions[:, 1] + ax.scatter(x_coords, y_coords, + c='red', + s=dot_size, + alpha=1.0, + edgecolors='white', + linewidth=1) + + ax.axis("off") + plt.title(f"Count: {self.count:.1f}") + + if save: + if save_path is None: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"crowd_dots_{timestamp}.png" + + plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200) + print(f"이미지 저장 완료: {save_path}") + + # Figure를 numpy 배열로 변환 + fig.canvas.draw() + image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + image_from_plot = image_from_plot[:,:,:3] # RGB로 변환 + + return fig.canvas, image_from_plot + + def crowd_count(self): + """ + 가장 최근 예측의 군중 수를 반환합니다. + + Returns: + float: 예측된 군중 수 + None: 아직 예측이 수행되지 않은 경우 + """ + return self.count + + def get_density_map(self): + """ + 가장 최근 예측의 밀도 맵을 반환합니다. 
+ + Returns: + numpy.ndarray: 밀도 맵 + None: 아직 예측이 수행되지 않은 경우 + """ + return self.density_map \ No newline at end of file diff --git a/custom/init_get_model.py b/custom/init_get_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom/json2seg.py b/custom/json2seg.py new file mode 100644 index 0000000000000000000000000000000000000000..d26c4386af82d09743ee6172536d4cd4ced77428 --- /dev/null +++ b/custom/json2seg.py @@ -0,0 +1,16 @@ +import json + +def get_segmentation_by_id(target_id, json_file="/home/jungseoik/data/PR/CLIP-EBC/assets/seg.json" ): + with open(json_file, "r", encoding="utf-8") as f: + data = json.load(f) + + # annotations 리스트 가져오기 + annotations = data.get("annotations", []) + + # annotations 순회하면서 id가 target_id인 항목 찾기 + for ann in annotations: + if ann.get("id") == target_id: + return ann.get("segmentation", None) + + # 해당 id가 없으면 None 반환 + return None \ No newline at end of file diff --git a/custom/mock_gen.py b/custom/mock_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..72e93e78fb08396f92de9e380e4a28a5a81becc9 --- /dev/null +++ b/custom/mock_gen.py @@ -0,0 +1,80 @@ +import pandas as pd +import numpy as np +import random +def create_mock_data_heatmap(): + # 기본 구조 생성 + sections = [f'구역 {i}' for i in range(1, 7)] + months = list(range(1, 13)) + years = list(range(2020, 2025)) + + # 데이터프레임용 리스트 생성 + data = [] + for year in years: + for section in sections: + for month in months: + data.append({ + 'section': section, + 'month': month, + 'year': year, + 'crowd_count': np.random.randint(30000, 500000) + }) + + # DataFrame 생성 + df = pd.DataFrame(data) + return df + +def create_mock_data_table(): + mock_data = { + 'section': [f'구역 {i}' for i in range(1, 7)], + 'count': np.random.randint(10000, 300000, 6) +} + + df = pd.DataFrame(mock_data) + return df + +def create_mock_data_donut(min_value=10000, max_value=500000): + """ + 가상의 인구 이동 데이터를 생성합니다. 
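+
+    Example (illustrative; the percentages are random, and make_donut is the chart
+    helper defined in custom/visual.py):
+
+        >>> from custom.visual import make_donut
+        >>> inbound_pct, outbound_pct = create_mock_data_donut()
+        >>> chart = make_donut(inbound_pct, 'Inbound', 'blue')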
+ + Returns: + tuple: (인바운드 이동 비율, 아웃바운드 이동 비율) + """ + # 랜덤 값 생성 (10000~500000 사이) + inbound = random.randint(min_value, max_value) + outbound = random.randint(min_value, max_value) + + # 전체 값 대비 비율 계산 (0-100 사이의 값으로 변환) + total = inbound + outbound + inbound_percent = round((inbound / total) * 100) + outbound_percent = round((outbound / total) * 100) + + return inbound_percent, outbound_percent + + +def create_mock_data_inout(): + """ + 방문객 데이터 랜덤 생성 + - 이번달 방문객: 150,000 ~ 500,000 + - 오늘 방문객: 5,000 ~ 100,000 + - delta는 전월/전일 대비 증감량 (-30% ~ +30%) + """ + # 이번달 방문객 (더 큰 범위) + monthly_visitors = random.randint(150000, 500000) + monthly_delta = int(monthly_visitors * random.uniform(-0.3, 0.3)) # 30% 범위 내 증감 + + # 오늘 방문객 (더 작은 범위) + daily_visitors = random.randint(5000, 100000) + daily_delta = int(daily_visitors * random.uniform(-0.3, 0.3)) # 30% 범위 내 증감 + + return { + 'top': { + 'state': '이번달 방문객', + 'visitor': monthly_visitors, + 'delta': monthly_delta + }, + 'bottom': { + 'state': '오늘 방문객', + 'visitor': daily_visitors, + 'delta': daily_delta + } + } \ No newline at end of file diff --git a/custom/visual.py b/custom/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a1fde20f7b893f86e9eeca4e0212a1b6b7d0fd --- /dev/null +++ b/custom/visual.py @@ -0,0 +1,100 @@ +import altair as alt +import pandas as pd +from typing import Tuple, Literal, Union +# Heatmap +def make_heatmap(input_df, input_y, input_x, input_color, input_color_theme): + heatmap = alt.Chart(input_df).mark_rect().encode( + y=alt.Y(f'{input_y}:O', axis=alt.Axis(title="Month", titleFontSize=18, titlePadding=15, titleFontWeight=900, labelAngle=0)), + x=alt.X(f'{input_x}:O', axis=alt.Axis(title="", titleFontSize=18, titlePadding=15, titleFontWeight=900, labelAngle=0)), + color=alt.Color(f'max({input_color}):Q', + legend=None, + scale=alt.Scale(scheme=input_color_theme)), + stroke=alt.value('black'), + strokeWidth=alt.value(0.25), + ).properties(width=900 + ).configure_axis( + labelFontSize=12, + titleFontSize=12 + ) + # height=300 + return heatmap + + +# Donut chart +def make_donut( + input_response: float, + input_text: str, + input_color: Literal['blue', 'green', 'orange', 'red'] +) -> alt.LayerChart: + """ + Altair를 사용하여 지정된 퍼센트, 레이블, 색상 스키마로 도넛 차트를 생성합니다. + + 함수 구조: + 1. 입력 색상에 따른 색상 스키마 정의 + 2. 두 개의 DataFrame 생성: + - 퍼센트 표시를 위한 메인 데이터 + - 전체 원을 위한 배경 데이터 + 3. 
세 개의 레이어 생성: + - 배경 원 (plot_bg) + - 퍼센트 호 (plot) + - 중앙 텍스트 표시 + + 매개변수: + ---------- + input_response : float + 표시할 퍼센트 값 (0-100 사이) + input_text : str + 차트에 표시할 레이블 텍스트 + input_color : str + 사용할 색상 스키마 ('blue', 'green', 'orange', 'red' 중 하나) + + 반환값: + ------- + alt.LayerChart + 배경, 퍼센트 호, 중앙 텍스트가 결합된 Altair 레이어 차트 + + 사용 예시: + --------- + >>> chart = make_donut(75, "완료", "blue") + >>> chart.save('donut.html') + """ + if input_color == 'blue': + chart_color = ['#29b5e8', '#155F7A'] + if input_color == 'green': + chart_color = ['#27AE60', '#12783D'] + if input_color == 'orange': + chart_color = ['#F39C12', '#875A12'] + if input_color == 'red': + chart_color = ['#E74C3C', '#781F16'] + + source = pd.DataFrame({ + "Topic": ['', input_text], + "% value": [100-input_response, input_response] + }) + source_bg = pd.DataFrame({ + "Topic": ['', input_text], + "% value": [100, 0] + }) + + plot = alt.Chart(source).mark_arc(innerRadius=45, cornerRadius=25).encode( + theta="% value", + color= alt.Color("Topic:N", + scale=alt.Scale( + #domain=['A', 'B'], + domain=[input_text, ''], + # range=['#29b5e8', '#155F7A']), # 31333F + range=chart_color), + legend=None), + ).properties(width=130, height=130) + + text = plot.mark_text(align='center', color="#29b5e8", font="Lato", fontSize=32, fontWeight=700, fontStyle="italic").encode(text=alt.value(f'{input_response} %')) + plot_bg = alt.Chart(source_bg).mark_arc(innerRadius=45, cornerRadius=20).encode( + theta="% value", + color= alt.Color("Topic:N", + scale=alt.Scale( + # domain=['A', 'B'], + domain=[input_text, ''], + range=chart_color), # 31333F + legend=None), + ).properties(width=130, height=130) + return plot_bg + plot + text \ No newline at end of file diff --git a/losses/__init__.py b/losses/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..4aa93e81c8f5ccd020754ca1b706ce6f5527fb6f --- /dev/null +++ b/losses/__init__.py @@ -0,0 +1,7 @@ +from .dm_loss import DMLoss +from .dace_loss import DACELoss + +__all__ = [ + "DMLoss", + "DACELoss", +] diff --git a/losses/bregman_pytorch.py b/losses/bregman_pytorch.py new file mode 100755 index 0000000000000000000000000000000000000000..d24082d42162ae5cada3597055c4a2d3a31467df --- /dev/null +++ b/losses/bregman_pytorch.py @@ -0,0 +1,144 @@ +# Code modified from https://github.com/cvlab-stonybrook/DM-Count/blob/master/losses/bregman_pytorch.py +import torch +from torch import Tensor +from torch.cuda.amp import autocast +from typing import Union, Tuple, Dict + +M_EPS = 1e-16 + + +@autocast(enabled=True, dtype=torch.float32) # avoid numerical instability +def sinkhorn( + a: Tensor, + b: Tensor, + C: Tensor, + reg: float = 1e-1, + maxIter: int = 1000, + stopThr: float = 1e-9, + verbose: bool = False, + log: bool = True, + eval_freq: int = 10, + print_freq: int = 200, +) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]: + """ + Solve the entropic regularization optimal transport + The input should be PyTorch tensors + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) + s.t. \gamma 1 = a + \gamma^T 1= b + \gamma\geq 0 + where : + - C is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are target and source measures (sum to 1) + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]. 
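+
+    Concretely, with K = exp(-C / reg), each iteration applies the scaling updates
+
+    .. math::
+        v \leftarrow b / (K^T u), \qquad u \leftarrow a / (K v)
+
+    (element-wise division; the implementation below adds a small constant M_EPS to the
+    denominators for numerical stability), and the transport plan is recovered as
+    :math:`P = \mathrm{diag}(u) K \mathrm{diag}(v)`.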
+ + Parameters + ---------- + a : torch.tensor (na,) + samples measure in the target domain + b : torch.tensor (nb,) + samples in the source domain + C : torch.tensor (na,nb) + loss matrix + reg : float + Regularization term > 0 + maxIter : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error ( > 0 ) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : (na x nb) torch.tensor + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + References + ---------- + [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + See Also + -------- + """ + + device = a.device + na, nb = C.shape + + # a = a.view(-1, 1) + # b = b.view(-1, 1) + + assert na >= 1 and nb >= 1, f"C needs to be 2d. Found C.shape = {C.shape}" + assert na == a.shape[0] and nb == b.shape[0], f"Shape of a ({a.shape}) or b ({b.shape}) does not match that of C ({C.shape})" + assert reg > 0, f"reg should be greater than 0. Found reg = {reg}" + assert a.min() >= 0. and b.min() >= 0., f"Elements in a and b should be nonnegative. Found a.min() = {a.min()}, b.min() = {b.min()}" + + if log: + log = {"err": []} + + u = torch.ones((na), dtype=a.dtype).to(device) / na + v = torch.ones((nb), dtype=b.dtype).to(device) / nb + + K = torch.empty(C.shape, dtype=C.dtype).to(device) + torch.div(C, -reg, out=K) + torch.exp(K, out=K) + + b_hat = torch.empty(b.shape, dtype=C.dtype).to(device) + + it = 1 + err = 1 + + # allocate memory beforehand + KTu = torch.empty(v.shape, dtype=v.dtype).to(device) + Kv = torch.empty(u.shape, dtype=u.dtype).to(device) + + while (err > stopThr and it <= maxIter): + upre, vpre = u, v + # torch.matmul(u, K, out=KTu) + KTu = torch.matmul(u.view(1, -1), K).view(-1) + v = torch.div(b, KTu + M_EPS) + # torch.matmul(K, v, out=Kv) + Kv = torch.matmul(K, v.view(-1, 1)).view(-1) + u = torch.div(a, Kv + M_EPS) + + if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \ + torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)): + print("Warning: numerical errors at iteration", it) + u, v = upre, vpre + break + + if log and it % eval_freq == 0: + # we can speed up the process by checking for the error only all + # the eval_freq iterations + # below is equivalent to: + # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0) + # but with more memory efficient + b_hat = (torch.matmul(u.view(1, -1), K) * v.view(1, -1)).view(-1) + err = (b - b_hat).pow(2).sum().item() + # err = (b - b_hat).abs().sum().item() + log["err"].append(err) + + if verbose and it % print_freq == 0: + print("iteration {:5d}, constraint error {:5e}".format(it, err)) + + it += 1 + + if log: + log["u"] = u + log["v"] = v + log["alpha"] = reg * torch.log(u + M_EPS) + log["beta"] = reg * torch.log(v + M_EPS) + + # transport plan + P = u.reshape(-1, 1) * K * v.reshape(1, -1) + if log: + return P, log + else: + return P diff --git a/losses/dace_loss.py b/losses/dace_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..80ff992a443c2b7dc8b2e11ad939254c17b3b955 --- /dev/null +++ b/losses/dace_loss.py @@ -0,0 +1,70 @@ +import torch +from torch import nn, Tensor +from typing import Any, List, Tuple, Dict + +from .dm_loss import DMLoss +from .utils import _reshape_density + + +class DACELoss(nn.Module): + def __init__( + self, + bins: List[Tuple[float, 
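A minimal usage sketch of the `sinkhorn` solver above, with arbitrary sizes and regularization; it only assumes the file is importable as `losses.bregman_pytorch`.

```python
import torch
from losses.bregman_pytorch import sinkhorn

# Two discrete probability vectors (nonnegative, summing to 1).
a = torch.full((4,), 0.25)        # target marginal, shape (na,)
b = torch.full((6,), 1.0 / 6.0)   # source marginal, shape (nb,)
C = torch.rand(4, 6)              # arbitrary nonnegative cost matrix, shape (na, nb)

P, info = sinkhorn(a, b, C, reg=0.1, maxIter=500, log=True)

print(P.shape)        # torch.Size([4, 6])
print(P.sum(dim=1))   # ~ a: row sums of the plan match the target marginal
print(P.sum(dim=0))   # ~ b: column sums match the source marginal
print(info["alpha"].shape, info["beta"].shape)  # dual potentials, shapes (4,) and (6,)
```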
float]], + reduction: int, + weight_count_loss: float = 1.0, + count_loss: str = "mae", + **kwargs: Any + ) -> None: + super().__init__() + assert len(bins) > 0, f"Expected at least one bin, got {bins}" + assert all([len(b) == 2 for b in bins]), f"Expected all bins to be of length 2, got {bins}" + assert all([b[0] <= b[1] for b in bins]), f"Expected all bins to be in increasing order, got {bins}" + self.bins = bins + self.reduction = reduction + self.cross_entropy_fn = nn.CrossEntropyLoss(reduction="none") + + count_loss = count_loss.lower() + assert count_loss in ["mae", "mse", "dmcount"], f"Expected count_loss to be one of ['mae', 'mse', 'dmcount'], got {count_loss}" + self.count_loss = count_loss + if self.count_loss == "mae": + self.use_dm_loss = False + self.count_loss_fn = nn.L1Loss(reduction="none") + elif self.count_loss == "mse": + self.use_dm_loss = False + self.count_loss_fn = nn.MSELoss(reduction="none") + else: + self.use_dm_loss = True + assert "input_size" in kwargs, f"Expected input_size to be in kwargs when count_loss='dmcount', got {kwargs}" + self.count_loss_fn = DMLoss(reduction=reduction, **kwargs) + + self.weight_count_loss = weight_count_loss + + def _bin_count(self, density_map: Tensor) -> Tensor: + class_map = torch.zeros_like(density_map, dtype=torch.long) + for idx, (low, high) in enumerate(self.bins): + mask = (density_map >= low) & (density_map <= high) + class_map[mask] = idx + return class_map.squeeze(1) # remove channel dimension + + def forward(self, pred_class: Tensor, pred_density: Tensor, target_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: + target_density = _reshape_density(target_density, reduction=self.reduction) if target_density.shape[-2:] != pred_density.shape[-2:] else target_density + assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}" + + target_class = self._bin_count(target_density) + + cross_entropy_loss = self.cross_entropy_fn(pred_class, target_class).sum(dim=(-1, -2)).mean() + + if self.use_dm_loss: + count_loss, loss_info = self.count_loss_fn(pred_density, target_density, target_points) + loss_info["ce_loss"] = cross_entropy_loss.detach() + else: + count_loss = self.count_loss_fn(pred_density, target_density).sum(dim=(-1, -2, -3)).mean() + loss_info = { + "ce_loss": cross_entropy_loss.detach(), + f"{self.count_loss}_loss": count_loss.detach(), + } + + loss = cross_entropy_loss + self.weight_count_loss * count_loss + loss_info["loss"] = loss.detach() + + return loss, loss_info diff --git a/losses/dm_loss.py b/losses/dm_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..5ab72a496867bb1464a73b97be007014d58153d0 --- /dev/null +++ b/losses/dm_loss.py @@ -0,0 +1,124 @@ +import torch +from torch import nn, Tensor +from torch.cuda.amp import autocast +from typing import List, Any, Tuple, Dict + +from .bregman_pytorch import sinkhorn +from .utils import _reshape_density + +EPS = 1e-8 + + +class OTLoss(nn.Module): + def __init__( + self, + input_size: int, + reduction: int, + norm_cood: bool, + num_of_iter_in_ot: int = 100, + reg: float = 10.0 + ) -> None: + super().__init__() + assert input_size % reduction == 0 + + self.input_size = input_size + self.reduction = reduction + self.norm_cood = norm_cood + self.num_of_iter_in_ot = num_of_iter_in_ot + self.reg = reg + + # coordinate is same to image space, set to constant since crop size is same + self.cood = 
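To make the classification target used by `DACELoss` concrete, here is a small standalone re-implementation of the `_bin_count` mapping above; the bin edges are hypothetical, not the ones shipped in the configs.

```python
import torch

# Hypothetical bin edges: each blockwise count is mapped to the index of the
# [low, high] interval that contains it; float("inf") caps the last bin.
bins = [(0, 0), (1, 1), (2, 2), (3, 4), (5, float("inf"))]

# A (1, 1, 2, 3) "density map" of per-block counts after reduction.
density = torch.tensor([[[[0., 1., 2.],
                          [3., 4., 7.]]]])

class_map = torch.zeros_like(density, dtype=torch.long)
for idx, (low, high) in enumerate(bins):
    class_map[(density >= low) & (density <= high)] = idx

print(class_map.squeeze(1))
# tensor([[[0, 1, 2],
#          [3, 3, 4]]])
```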
torch.arange(0, input_size, step=reduction, dtype=torch.float32) + reduction / 2 + self.density_size = self.cood.size(0) + self.cood.unsqueeze_(0) # [1, #cood] + self.cood = self.cood / input_size * 2 - 1 if self.norm_cood else self.cood + self.output_size = self.cood.size(1) + + @autocast(enabled=True, dtype=torch.float32) # avoid numerical instability + def forward(self, pred_density: Tensor, normed_pred_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, float, Tensor]: + batch_size = normed_pred_density.size(0) + assert len(target_points) == batch_size, f"Expected target_points to have length {batch_size}, but got {len(target_points)}" + assert self.output_size == normed_pred_density.size(2) + device = pred_density.device + + loss = torch.zeros([1]).to(device) + ot_obj_values = torch.zeros([1]).to(device) + wd = 0 # Wasserstein distance + cood = self.cood.to(device) + for idx, points in enumerate(target_points): + if len(points) > 0: + # compute l2 square distance, it should be source target distance. [#gt, #cood * #cood] + points = points / self.input_size * 2 - 1 if self.norm_cood else points + x = points[:, 0].unsqueeze_(1) # [#gt, 1] + y = points[:, 1].unsqueeze_(1) + x_dist = -2 * torch.matmul(x, cood) + x * x + cood * cood # [#gt, #cood] + y_dist = -2 * torch.matmul(y, cood) + y * y + cood * cood + y_dist.unsqueeze_(2) + x_dist.unsqueeze_(1) + dist = y_dist + x_dist + dist = dist.view((dist.size(0), -1)) # size of [#gt, #cood * #cood] + + source_prob = normed_pred_density[idx][0].view([-1]).detach() + target_prob = (torch.ones([len(points)]) / len(points)).to(device) + # use sinkhorn to solve OT, compute optimal beta. + P, log = sinkhorn(target_prob, source_prob, dist, self.reg, maxIter=self.num_of_iter_in_ot, log=True) + beta = log["beta"] # size is the same as source_prob: [#cood * #cood] + ot_obj_values += torch.sum(normed_pred_density[idx] * beta.view([1, self.output_size, self.output_size])) + # compute the gradient of OT loss to predicted density (pred_density). + # im_grad = beta / source_count - < beta, source_density> / (source_count)^2 + source_density = pred_density[idx][0].view([-1]).detach() + source_count = source_density.sum() + gradient_1 = (source_count) / (source_count * source_count+ EPS) * beta # size of [#cood * #cood] + gradient_2 = (source_density * beta).sum() / (source_count * source_count + EPS) # size of 1 + gradient = gradient_1 - gradient_2 + gradient = gradient.detach().view([1, self.output_size, self.output_size]) + # Define loss = . The gradient of loss w.r.t predicted density is im_grad. 
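+                # In other words, loss = <pred_density, im_grad>: the inner product of the
+                # predicted density with the detached gradient computed above, so that
+                # d(loss) / d(pred_density) = im_grad, where
+                #   im_grad = beta / count - <beta, source_density> / count**2
+                # and count = source_density.sum(). Since beta, source_density and
+                # source_count are all detached, only pred_density receives gradients here.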
+ loss += torch.sum(pred_density[idx] * gradient) + wd += torch.sum(dist * P).item() + + return loss, wd, ot_obj_values + + +class DMLoss(nn.Module): + def __init__( + self, + input_size: int, + reduction: int, + norm_cood: bool = False, + weight_ot: float = 0.1, + weight_tv: float = 0.01, + **kwargs: Any + ) -> None: + super().__init__() + self.ot_loss = OTLoss(input_size, reduction, norm_cood, **kwargs) + self.tv_loss = nn.L1Loss(reduction="none") + self.count_loss = nn.L1Loss(reduction="mean") + self.weight_ot = weight_ot + self.weight_tv = weight_tv + + @autocast(enabled=True, dtype=torch.float32) # avoid numerical instability + def forward(self, pred_density: Tensor, target_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: + target_density = _reshape_density(target_density, reduction=self.ot_loss.reduction) if target_density.shape[-2:] != pred_density.shape[-2:] else target_density + assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}" + + pred_count = pred_density.view(pred_density.shape[0], -1).sum(dim=1) + normed_pred_density = pred_density / (pred_count.view(-1, 1, 1, 1) + EPS) + target_count = torch.tensor([len(p) for p in target_points], dtype=torch.float32).to(target_density.device) + normed_target_density = target_density / (target_count.view(-1, 1, 1, 1) + EPS) + + ot_loss, _, _ = self.ot_loss(pred_density, normed_pred_density, target_points) + + tv_loss = (self.tv_loss(normed_pred_density, normed_target_density).sum(dim=(1, 2, 3)) * target_count).mean() + + count_loss = self.count_loss(pred_count, target_count) + + loss = ot_loss * self.weight_ot + tv_loss * self.weight_tv + count_loss + + loss_info = { + "loss": loss.detach(), + "ot_loss": ot_loss.detach(), + "tv_loss": tv_loss.detach(), + "count_loss": count_loss.detach(), + } + + return loss, loss_info diff --git a/losses/utils.py b/losses/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..d1738455ad2667136ebf3ebf427932d23d418da6 --- /dev/null +++ b/losses/utils.py @@ -0,0 +1,9 @@ +from torch import Tensor + + +def _reshape_density(density: Tensor, reduction: int) -> Tensor: + assert len(density.shape) == 4, f"Expected 4D (B, 1, H, W) tensor, got {density.shape}" + assert density.shape[1] == 1, f"Expected 1 channel, got {density.shape[1]}" + assert density.shape[2] % reduction == 0, f"Expected height to be divisible by {reduction}, got {density.shape[2]}" + assert density.shape[3] % reduction == 0, f"Expected width to be divisible by {reduction}, got {density.shape[3]}" + return density.reshape(density.shape[0], 1, density.shape[2] // reduction, reduction, density.shape[3] // reduction, reduction).sum(dim=(-1, -3)) diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e49988f77a178abd264d211c298544835f3f24da --- /dev/null +++ b/main.py @@ -0,0 +1,75 @@ +import argparse +import os +from custom.clip_ebc_onnx import ClipEBCOnnx + +def parse_args(): + parser = argparse.ArgumentParser(description='CLIP-EBC Crowd Counting (ONNX)') + parser.add_argument('--image', required=True, help='Path to input image') + parser.add_argument('--model', default='assets/CLIP_EBC_nwpu_rmse_onnx.onnx', help='Path to ONNX model') + parser.add_argument('--visualize', choices=['density', 'dots', 'all', 'none'], + default='none', help='Visualization type') + parser.add_argument('--save', action='store_true', + help='Save 
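The `_reshape_density` helper in losses/utils.py folds every `reduction x reduction` block of the full-resolution density map into one cell by summing, so per-image counts are preserved. A short sketch with arbitrary shapes:

```python
import torch
from losses.utils import _reshape_density

density = torch.rand(2, 1, 64, 64)        # (B, 1, H, W) full-resolution density maps
reduced = _reshape_density(density, reduction=8)

print(reduced.shape)                      # torch.Size([2, 1, 8, 8])
# Block-summing preserves the total count of each image.
print(torch.allclose(density.sum(dim=(1, 2, 3)), reduced.sum(dim=(1, 2, 3))))  # True
```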
visualization results') + parser.add_argument('--output-dir', default='results', + help='Directory to save results') + + # 시각화 관련 매개변수 + parser.add_argument('--alpha', type=float, default=0.5, + help='Alpha value for density map') + parser.add_argument('--dot-size', type=int, default=20, + help='Dot size for dot visualization') + parser.add_argument('--sigma', type=float, default=1, + help='Sigma value for Gaussian filter') + parser.add_argument('--percentile', type=float, default=97, + help='Percentile threshold for dot visualization') + + + + return parser.parse_args() + +def main(): + args = parse_args() + + # 모델 초기화 - ONNX 버전 + model = ClipEBCOnnx( + onnx_model_path=args.model + ) + + # 출력 디렉토리 생성 + if args.save: + os.makedirs(args.output_dir, exist_ok=True) + + # 예측 수행 + count = model.predict(args.image) + print(f"예측된 군중 수: {count:.2f}") + + # 시각화 + if args.visualize in ['density', 'all']: + save_path = os.path.join(args.output_dir, 'density_map.png') if args.save else None + fig, density_map = model.visualize_density_map( + alpha=args.alpha, + save=args.save, + save_path=save_path + ) + + if args.visualize in ['dots', 'all']: + save_path = os.path.join(args.output_dir, 'dot_map.png') if args.save else None + canvas, dot_map = model.visualize_dots( + dot_size=args.dot_size, + sigma=args.sigma, + percentile=args.percentile, + save=args.save, + save_path=save_path + ) + + # matplotlib figure 닫기 (메모리 누수 방지) + if args.visualize in ['density', 'all']: + import matplotlib.pyplot as plt + plt.close(fig) + + if args.visualize in ['dots', 'all']: + import matplotlib.pyplot as plt + plt.close(canvas.figure) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a0df8a8de23a23443668e7c8104479a82a0c3cd4 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,49 @@ +from typing import List, Tuple, Optional, Any, Union + +from .model import _classifier, _regressor, Classifier, Regressor +from .clip import _clip_ebc, CLIP_EBC +import assets + +clip_names = ["resnet50", "resnet50x4", "resnet50x16", "resnet50x64", "resnet101", "vit_b_16", "vit_b_32", "vit_l_14"] + + +def get_model( + backbone: str, + input_size: int, + reduction: int, + bins: Optional[List[Tuple[float, float]]] = None, + anchor_points: Optional[List[float]] = None, + **kwargs: Any, +) -> Union[Regressor, Classifier, CLIP_EBC]: + backbone = backbone.lower() + if "clip" in backbone: + backbone = backbone[5:] + assert backbone in clip_names, f"Expected backbone to be in {clip_names}, got {backbone}" + return _clip_ebc( + backbone=backbone, + input_size=input_size, + reduction=reduction, + bins=bins, + anchor_points=anchor_points, + **kwargs + ) + elif bins is None and anchor_points is None: + return _regressor( + backbone=backbone, + input_size=input_size, + reduction=reduction, + ) + else: + assert bins is not None and anchor_points is not None, f"Expected bins and anchor_points to be both None or not None, got {bins} and {anchor_points}" + return _classifier( + backbone=backbone, + input_size=input_size, + reduction=reduction, + bins=bins, + anchor_points=anchor_points, + ) + + +__all__ = [ + "get_model", +] diff --git a/models/clip/__init__.py b/models/clip/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2448575b7fa674c224f48027693bd7208fa7a9d0 --- /dev/null +++ b/models/clip/__init__.py @@ -0,0 +1,7 @@ +from .model import CLIP_EBC, _clip_ebc + + +__all__ = [ + "CLIP_EBC", + 
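main.py above is a thin CLI wrapper around `ClipEBCOnnx`; the same flow can be driven programmatically as sketched below. The image and output paths are placeholders, and only calls already used by main.py are assumed to exist.

```python
from custom.clip_ebc_onnx import ClipEBCOnnx

model = ClipEBCOnnx(onnx_model_path="assets/CLIP_EBC_nwpu_rmse_onnx.onnx")

count = model.predict("path/to/crowd.jpg")   # placeholder image path
print(f"Estimated crowd count: {count:.2f}")

fig, density_map = model.visualize_density_map(
    alpha=0.5, save=True, save_path="results/density_map.png"  # placeholder output path
)
```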
"_clip_ebc", +] diff --git a/models/clip/_clip/__init__.py b/models/clip/_clip/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..62dedfd41e9d311ea0f4175c953f0e9b364e427c --- /dev/null +++ b/models/clip/_clip/__init__.py @@ -0,0 +1,273 @@ +import torch +import os +from typing import Tuple, Optional, Any, Union +import json + +from .utils import tokenize, transform +from .prepare import prepare +from .text_encoder import CLIPTextEncoder +from .image_encoder import ModifiedResNet, VisionTransformer +from .model import CLIP + + +curr_dir = os.path.dirname(os.path.abspath(__file__)) + +clip_model_names = [ + "clip_resnet50", + "clip_resnet101", + "clip_resnet50x4", + "clip_resnet50x16", + "clip_resnet50x64", + "clip_vit_b_32", + "clip_vit_b_16", + "clip_vit_l_14", + "clip_vit_l_14_336px", +] + +clip_image_encoder_names = [f"clip_image_encoder_{name[5:]}" for name in clip_model_names] +clip_text_encoder_names = [f"clip_text_encoder_{name[5:]}" for name in clip_model_names] + + +for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names: + model_weights_path = os.path.join(curr_dir, "weights", f"{name}.pth") + model_config_path = os.path.join(curr_dir, "configs", f"{name}.json") + if not os.path.exists(os.path.join(curr_dir, "weights", f"{name}.pth")) or not os.path.exists(os.path.join(curr_dir, "configs", f"{name}.json")): + prepare() + break + + +for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names: + assert os.path.exists(os.path.join(curr_dir, "weights", f"{name}.pth")), f"Missing {name}.pth in weights folder. Please run models/clip/prepare.py to download the weights." + assert os.path.exists(os.path.join(curr_dir, "configs", f"{name}.json")), f"Missing {name}.json in configs folder. Please run models/clip/prepare.py to download the configs." 
+ + +def _clip(name: str, input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + with open(os.path.join(curr_dir, "configs", f"clip_{name}.json"), "r") as f: + config = json.load(f) + + model = CLIP( + embed_dim=config["embed_dim"], + # vision + image_resolution=config["image_resolution"], + vision_layers=config["vision_layers"], + vision_width=config["vision_width"], + vision_patch_size=config["vision_patch_size"], + # text + context_length=config["context_length"], + vocab_size=config["vocab_size"], + transformer_width=config["transformer_width"], + transformer_heads=config["transformer_heads"], + transformer_layers=config["transformer_layers"] + ) + state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_{name}.pth"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + if input_size is not None: + input_size = (input_size, input_size) if isinstance(input_size, int) else input_size + if name.startswith("vit"): + model.visual.adjust_pos_embed(*input_size) + + return model + + +def _resnet( + name: str, + reduction: int = 32, + features_only: bool = False, + out_indices: Optional[Tuple[int, ...]] = None, + **kwargs: Any +) -> ModifiedResNet: + with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f: + config = json.load(f) + model = ModifiedResNet( + layers=config["vision_layers"], + output_dim=config["embed_dim"], + input_resolution=config["image_resolution"], + width=config["vision_width"], + heads=config["vision_heads"], + features_only=features_only, + out_indices=out_indices, + reduction=reduction + ) + state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if len(missing_keys) > 0 or len(unexpected_keys) > 0: + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + else: + print(f"All keys matched successfully.") + + return model + + +def _vit(name: str, features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer: + with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f: + config = json.load(f) + model = VisionTransformer( + input_resolution=config["image_resolution"], + patch_size=config["vision_patch_size"], + output_dim=config["embed_dim"], + width=config["vision_width"], + layers=config["vision_layers"], + heads=config["vision_heads"], + features_only=features_only + ) + state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if len(missing_keys) > 0 or len(unexpected_keys) > 0: + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + else: + print(f"All keys matched successfully.") + + if input_size is not None: + input_size = (input_size, input_size) if isinstance(input_size, int) else input_size + model.adjust_pos_embed(*input_size) + return model + + +def _text_encoder(name: str) -> CLIPTextEncoder: + with open(os.path.join(curr_dir, "configs", f"clip_text_encoder_{name}.json"), "r") as f: + config = json.load(f) + model = CLIPTextEncoder( + embed_dim=config["embed_dim"], + context_length=config["context_length"], + vocab_size=config["vocab_size"], + transformer_width=config["transformer_width"], + transformer_heads=config["transformer_heads"], + 
transformer_layers=config["transformer_layers"] + ) + state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_text_encoder_{name}.pth"), map_location="cpu") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if len(missing_keys) > 0 or len(unexpected_keys) > 0: + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + else: + print(f"All keys matched successfully.") + + return model + + + +# CLIP models +def resnet50_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("resnet50", input_size) + +def resnet101_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("resnet101", input_size) + +def resnet50x4_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("resnet50x4", input_size) + +def resnet50x16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("resnet50x16", input_size) + +def resnet50x64_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("resnet50x64", input_size) + +def vit_b_32_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("vit_b_32", input_size) + +def vit_b_16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("vit_b_16", input_size) + +def vit_l_14_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("vit_l_14", input_size) + +def vit_l_14_336px_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP: + return _clip("vit_l_14_336px", input_size) + + +# CLIP image encoders +def resnet50_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet: + return _resnet("resnet50", features_only=features_only, out_indices=out_indices, **kwargs) + +def resnet101_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet: + return _resnet("resnet101", features_only=features_only, out_indices=out_indices, **kwargs) + +def resnet50x4_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet: + return _resnet("resnet50x4", features_only=features_only, out_indices=out_indices, **kwargs) + +def resnet50x16_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet: + return _resnet("resnet50x16", features_only=features_only, out_indices=out_indices, **kwargs) + +def resnet50x64_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet: + return _resnet("resnet50x64", features_only=features_only, out_indices=out_indices, **kwargs) + +def vit_b_32_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer: + return _vit("vit_b_32", features_only=features_only, input_size=input_size, **kwargs) + +def vit_b_16_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer: + return _vit("vit_b_16", features_only=features_only, input_size=input_size, **kwargs) + +def vit_l_14_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer: + return _vit("vit_l_14", features_only=features_only, input_size=input_size, **kwargs) + +def vit_l_14_336px_img(features_only: bool = 
False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer: + return _vit("vit_l_14_336px", features_only=features_only, input_size=input_size, **kwargs) + + +# CLIP text encoders +def resnet50_txt() -> CLIPTextEncoder: + return _text_encoder("resnet50") + +def resnet101_txt() -> CLIPTextEncoder: + return _text_encoder("resnet101") + +def resnet50x4_txt() -> CLIPTextEncoder: + return _text_encoder("resnet50x4") + +def resnet50x16_txt() -> CLIPTextEncoder: + return _text_encoder("resnet50x16") + +def resnet50x64_txt() -> CLIPTextEncoder: + return _text_encoder("resnet50x64") + +def vit_b_32_txt() -> CLIPTextEncoder: + return _text_encoder("vit_b_32") + +def vit_b_16_txt() -> CLIPTextEncoder: + return _text_encoder("vit_b_16") + +def vit_l_14_txt() -> CLIPTextEncoder: + return _text_encoder("vit_l_14") + +def vit_l_14_336px_txt() -> CLIPTextEncoder: + return _text_encoder("vit_l_14_336px") + + +__all__ = [ + # utils + "tokenize", + "transform", + # clip models + "resnet50_clip", + "resnet101_clip", + "resnet50x4_clip", + "resnet50x16_clip", + "resnet50x64_clip", + "vit_b_32_clip", + "vit_b_16_clip", + "vit_l_14_clip", + "vit_l_14_336px_clip", + # clip image encoders + "resnet50_img", + "resnet101_img", + "resnet50x4_img", + "resnet50x16_img", + "resnet50x64_img", + "vit_b_32_img", + "vit_b_16_img", + "vit_l_14_img", + "vit_l_14_336px_img", + # clip text encoders + "resnet50_txt", + "resnet101_txt", + "resnet50x4_txt", + "resnet50x16_txt", + "resnet50x64_txt", + "vit_b_32_txt", + "vit_b_16_txt", + "vit_l_14_txt", + "vit_l_14_336px_txt", +] diff --git a/models/clip/_clip/blocks.py b/models/clip/_clip/blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..a97d3a9bb6b6bcf78f4051d23eae96af8a88e296 --- /dev/null +++ b/models/clip/_clip/blocks.py @@ -0,0 +1,137 @@ +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from collections import OrderedDict +from typing import Optional, Iterable + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: Tensor = None): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: Tensor) -> Tensor: + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: Tensor): + return self.resblocks(x) + + +class Bottleneck(nn.Module): + expansion = 4 + + def 
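A short sketch of using the image-encoder factories above as spatial feature extractors instead of pooled CLIP embeddings. It assumes the pretrained weights have already been fetched (importing `models.clip._clip` triggers `prepare()` if they are missing).

```python
import torch
from models.clip._clip import resnet50_img

# features_only=True returns intermediate feature maps instead of the pooled
# attention-pool embedding; out_indices=(-1,) keeps only the last of the 5 stages,
# and reduction=16 keeps layer4 at stride 1.
backbone = resnet50_img(features_only=True, out_indices=(-1,), reduction=16)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = backbone(x)
print(feats.shape)   # torch.Size([1, 2048, 14, 14])
```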
__init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) diff --git a/models/clip/_clip/bpe_simple_vocab_16e6.txt.gz b/models/clip/_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100755 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/models/clip/_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/models/clip/_clip/configs/clip_image_encoder_resnet101.json b/models/clip/_clip/configs/clip_image_encoder_resnet101.json new file mode 100644 index 0000000000000000000000000000000000000000..3740fc5e4059f451dbb86fbe76cb70c2768b3333 --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_resnet101.json @@ -0,0 +1,13 @@ +{ + "embed_dim": 512, + 
"image_resolution": 224, + "vision_layers": [ + 3, + 4, + 23, + 3 + ], + "vision_width": 64, + "vision_patch_size": null, + "vision_heads": 32 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_resnet50.json b/models/clip/_clip/configs/clip_image_encoder_resnet50.json new file mode 100644 index 0000000000000000000000000000000000000000..b5c8f50ee0bc29ee5cdd9054e1d958ebb7ba36ce --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_resnet50.json @@ -0,0 +1,13 @@ +{ + "embed_dim": 1024, + "image_resolution": 224, + "vision_layers": [ + 3, + 4, + 6, + 3 + ], + "vision_width": 64, + "vision_patch_size": null, + "vision_heads": 32 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_resnet50x16.json b/models/clip/_clip/configs/clip_image_encoder_resnet50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..6e02e97ce679b141efe5cfb43ac3004ec73abcc2 --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_resnet50x16.json @@ -0,0 +1,13 @@ +{ + "embed_dim": 768, + "image_resolution": 384, + "vision_layers": [ + 6, + 8, + 18, + 8 + ], + "vision_width": 96, + "vision_patch_size": null, + "vision_heads": 48 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_resnet50x4.json b/models/clip/_clip/configs/clip_image_encoder_resnet50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..d81bf5332aa5348d2fe31d07b2b18dc138bbcc42 --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_resnet50x4.json @@ -0,0 +1,13 @@ +{ + "embed_dim": 640, + "image_resolution": 288, + "vision_layers": [ + 4, + 6, + 10, + 6 + ], + "vision_width": 80, + "vision_patch_size": null, + "vision_heads": 40 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_resnet50x64.json b/models/clip/_clip/configs/clip_image_encoder_resnet50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..041b4d1d52cbbd8ca7edef18c17f82732003e1f5 --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_resnet50x64.json @@ -0,0 +1,13 @@ +{ + "embed_dim": 1024, + "image_resolution": 448, + "vision_layers": [ + 3, + 15, + 36, + 10 + ], + "vision_width": 128, + "vision_patch_size": null, + "vision_heads": 64 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_vit_b_16.json b/models/clip/_clip/configs/clip_image_encoder_vit_b_16.json new file mode 100644 index 0000000000000000000000000000000000000000..3dae670f2d405cdba5baf5389053e2fd7e724e45 --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_vit_b_16.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 512, + "image_resolution": 224, + "vision_layers": 12, + "vision_width": 768, + "vision_patch_size": 16, + "vision_heads": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_vit_b_32.json b/models/clip/_clip/configs/clip_image_encoder_vit_b_32.json new file mode 100644 index 0000000000000000000000000000000000000000..c84255edade36fa56d2a4ca8bb16b907da98a506 --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_vit_b_32.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 512, + "image_resolution": 224, + "vision_layers": 12, + "vision_width": 768, + "vision_patch_size": 32, + "vision_heads": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_vit_l_14.json b/models/clip/_clip/configs/clip_image_encoder_vit_l_14.json new file mode 100644 index 
0000000000000000000000000000000000000000..c15df10a0e2836152dd377e8611367b1051441bd --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_vit_l_14.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 768, + "image_resolution": 224, + "vision_layers": 24, + "vision_width": 1024, + "vision_patch_size": 14, + "vision_heads": 16 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_image_encoder_vit_l_14_336px.json b/models/clip/_clip/configs/clip_image_encoder_vit_l_14_336px.json new file mode 100644 index 0000000000000000000000000000000000000000..a5d8e90e5be455e544c514acd73f98525edf3697 --- /dev/null +++ b/models/clip/_clip/configs/clip_image_encoder_vit_l_14_336px.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 768, + "image_resolution": 336, + "vision_layers": 24, + "vision_width": 1024, + "vision_patch_size": 14, + "vision_heads": 16 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_resnet101.json b/models/clip/_clip/configs/clip_resnet101.json new file mode 100644 index 0000000000000000000000000000000000000000..4b3823c87a478b0d0d18c2e594b498f967eb2b02 --- /dev/null +++ b/models/clip/_clip/configs/clip_resnet101.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "image_resolution": 224, + "vision_layers": [ + 3, + 4, + 23, + 3 + ], + "vision_width": 64, + "vision_patch_size": null, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_resnet50.json b/models/clip/_clip/configs/clip_resnet50.json new file mode 100644 index 0000000000000000000000000000000000000000..131a9afbb1b6fb0df37ae7dbb7502a0b09574aff --- /dev/null +++ b/models/clip/_clip/configs/clip_resnet50.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "image_resolution": 224, + "vision_layers": [ + 3, + 4, + 6, + 3 + ], + "vision_width": 64, + "vision_patch_size": null, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_resnet50x16.json b/models/clip/_clip/configs/clip_resnet50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..1cb3b9c18d9fd398c9a3e24287a4274e19b2d25f --- /dev/null +++ b/models/clip/_clip/configs/clip_resnet50x16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 768, + "image_resolution": 384, + "vision_layers": [ + 6, + 8, + 18, + 8 + ], + "vision_width": 96, + "vision_patch_size": null, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_resnet50x4.json b/models/clip/_clip/configs/clip_resnet50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..94988302ae6c502e5d536b9a58223d1596943d3b --- /dev/null +++ b/models/clip/_clip/configs/clip_resnet50x4.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "image_resolution": 288, + "vision_layers": [ + 4, + 6, + 10, + 6 + ], + "vision_width": 80, + "vision_patch_size": null, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 640, + "transformer_heads": 10, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_resnet50x64.json b/models/clip/_clip/configs/clip_resnet50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..c9f5a234aad3e8077bc169f4e0307aea1d5f8726 --- /dev/null 
+++ b/models/clip/_clip/configs/clip_resnet50x64.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "image_resolution": 448, + "vision_layers": [ + 3, + 15, + 36, + 10 + ], + "vision_width": 128, + "vision_patch_size": null, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 1024, + "transformer_heads": 16, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_resnet101.json b/models/clip/_clip/configs/clip_text_encoder_resnet101.json new file mode 100644 index 0000000000000000000000000000000000000000..f38b9154cbc12f139457c6259be031388eceeebf --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_resnet101.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 512, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_resnet50.json b/models/clip/_clip/configs/clip_text_encoder_resnet50.json new file mode 100644 index 0000000000000000000000000000000000000000..b86ead4424071a8a51dbba3885ae008628c918c2 --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_resnet50.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 1024, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_resnet50x16.json b/models/clip/_clip/configs/clip_text_encoder_resnet50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..14933f5cf7dc7634f7b112921892c78ccf70f6d0 --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_resnet50x16.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 768, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_resnet50x4.json b/models/clip/_clip/configs/clip_text_encoder_resnet50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..13df0a51a74314338e1aa31a82182a70c6412fdb --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_resnet50x4.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 640, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 640, + "transformer_heads": 10, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_resnet50x64.json b/models/clip/_clip/configs/clip_text_encoder_resnet50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..39ace15d102cc7206dcc3c8117ac80e9cea8bc10 --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_resnet50x64.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 1024, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 1024, + "transformer_heads": 16, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_vit_b_16.json b/models/clip/_clip/configs/clip_text_encoder_vit_b_16.json new file mode 100644 index 0000000000000000000000000000000000000000..f38b9154cbc12f139457c6259be031388eceeebf --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_vit_b_16.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 512, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git 
a/models/clip/_clip/configs/clip_text_encoder_vit_b_32.json b/models/clip/_clip/configs/clip_text_encoder_vit_b_32.json new file mode 100644 index 0000000000000000000000000000000000000000..f38b9154cbc12f139457c6259be031388eceeebf --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_vit_b_32.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 512, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_vit_l_14.json b/models/clip/_clip/configs/clip_text_encoder_vit_l_14.json new file mode 100644 index 0000000000000000000000000000000000000000..14933f5cf7dc7634f7b112921892c78ccf70f6d0 --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_vit_l_14.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 768, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_text_encoder_vit_l_14_336px.json b/models/clip/_clip/configs/clip_text_encoder_vit_l_14_336px.json new file mode 100644 index 0000000000000000000000000000000000000000..14933f5cf7dc7634f7b112921892c78ccf70f6d0 --- /dev/null +++ b/models/clip/_clip/configs/clip_text_encoder_vit_l_14_336px.json @@ -0,0 +1,8 @@ +{ + "embed_dim": 768, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_vit_b_16.json b/models/clip/_clip/configs/clip_vit_b_16.json new file mode 100644 index 0000000000000000000000000000000000000000..9bedfb8e54f8adf1a4e23a6408ad759b306d89a1 --- /dev/null +++ b/models/clip/_clip/configs/clip_vit_b_16.json @@ -0,0 +1,12 @@ +{ + "embed_dim": 512, + "image_resolution": 224, + "vision_layers": 12, + "vision_width": 768, + "vision_patch_size": 16, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_vit_b_32.json b/models/clip/_clip/configs/clip_vit_b_32.json new file mode 100644 index 0000000000000000000000000000000000000000..57b94e5098baca7446d14e25da63f21d13809fb0 --- /dev/null +++ b/models/clip/_clip/configs/clip_vit_b_32.json @@ -0,0 +1,12 @@ +{ + "embed_dim": 512, + "image_resolution": 224, + "vision_layers": 12, + "vision_width": 768, + "vision_patch_size": 32, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 512, + "transformer_heads": 8, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_vit_l_14.json b/models/clip/_clip/configs/clip_vit_l_14.json new file mode 100644 index 0000000000000000000000000000000000000000..1619202b5d1eb6c749b9fba69206df6cdd75e814 --- /dev/null +++ b/models/clip/_clip/configs/clip_vit_l_14.json @@ -0,0 +1,12 @@ +{ + "embed_dim": 768, + "image_resolution": 224, + "vision_layers": 24, + "vision_width": 1024, + "vision_patch_size": 14, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/configs/clip_vit_l_14_336px.json b/models/clip/_clip/configs/clip_vit_l_14_336px.json new file mode 100644 index 0000000000000000000000000000000000000000..b3afa37914478cc7465086cd2e075fa8c7cccbe3 --- /dev/null +++ 
b/models/clip/_clip/configs/clip_vit_l_14_336px.json @@ -0,0 +1,12 @@ +{ + "embed_dim": 768, + "image_resolution": 336, + "vision_layers": 24, + "vision_width": 1024, + "vision_patch_size": 14, + "context_length": 77, + "vocab_size": 49408, + "transformer_width": 768, + "transformer_heads": 12, + "transformer_layers": 12 +} \ No newline at end of file diff --git a/models/clip/_clip/image_encoder.py b/models/clip/_clip/image_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..50b342d054986a7d66f1b711dd2eb9f646482608 --- /dev/null +++ b/models/clip/_clip/image_encoder.py @@ -0,0 +1,225 @@ +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from einops import rearrange +from typing import Tuple, Union, Any, List, Iterable, Optional + +from .blocks import LayerNorm, Transformer, Bottleneck, AttentionPool2d + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + def __init__( + self, + layers: Tuple[int, int, int, int], + output_dim: int, + input_resolution: int = 224, + width: int = 64, + heads: int = 8, + features_only: bool = False, + out_indices: Optional[Iterable[int]] = None, + reduction: int = 32, + **kwargs: Any, + ) -> None: + super().__init__() + input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution + assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}" + self.input_resolution = input_resolution + self.downsampling_rate = 32 # the rate at which the input is downsampled by the network + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=1 if reduction <= 16 else 2) + + self.features_only = features_only + if features_only: + self.out_indices = out_indices if out_indices is not None else range(5) + self.out_indices = [idx + 5 if idx < 0 else idx for idx in self.out_indices] # map negative indices to positive indices + self.out_indices = sorted(set(self.out_indices)) # remove duplicates and sort + assert min(self.out_indices) >= 0 and max(self.out_indices) <= 4, f"out_indices={self.out_indices} is invalid for a ResNet with 5 stages" + self.channels = width * 32 # the ResNet feature dimension + else: + self.out_indices = None + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = 
AttentionPool2d((input_resolution[0] // 32) * (input_resolution[1] // 32), embed_dim, heads, output_dim) + self.channels = output_dim + + self.reduction = self.downsampling_rate // 2 if reduction <= 16 else self.downsampling_rate + self.clip_embed_dim = output_dim + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def _stem(self, x: Tensor) -> Tensor: + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]: + x = x.type(self.conv1.weight.dtype) + x = self._stem(x) + + feats = [x] if self.features_only and 0 in self.out_indices else [] + + x = self.layer1(x) + if self.features_only and 1 in self.out_indices: + feats.append(x) + + x = self.layer2(x) + if self.features_only and 2 in self.out_indices: + feats.append(x) + + x = self.layer3(x) + if self.features_only and 3 in self.out_indices: + feats.append(x) + + x = self.layer4(x) + if self.features_only and 4 in self.out_indices: + feats.append(x) + + if self.features_only: + if len(self.out_indices) == 1: + return feats[0] + else: + return feats + else: + x = self.attnpool(x) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + input_resolution: Union[int, Tuple[int, int]], + patch_size: Union[int, Tuple[int, int]], + output_dim: int, + width: int, + layers: int, + heads: int, + features_only: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution + patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size + assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}" + assert isinstance(patch_size, tuple) and len(patch_size) == 2, f"patch_size should be a tuple of length 2, but got {patch_size}" + assert patch_size[0] == patch_size[1], f"ViT only supports square patches, patch_size={patch_size} is invalid." 
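When `features_only` is False, `ModifiedResNet` ends with the QKV attention pool rather than global average pooling. In isolation, `AttentionPool2d` from blocks.py behaves as in this sketch; the sizes mirror a ResNet-50 with a 224x224 input but are otherwise arbitrary (importing the package also runs the weight-preparation check).

```python
import torch
from models.clip._clip.blocks import AttentionPool2d

# A 7x7 grid of 2048-dim features is pooled into a single 1024-dim embedding.
pool = AttentionPool2d(spacial_dim=7 * 7, embed_dim=2048, num_heads=32, output_dim=1024)

x = torch.randn(2, 2048, 7, 7)   # (N, C, H, W) feature map from layer4
out = pool(x)
print(out.shape)                 # torch.Size([2, 1024])
```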
+ assert input_resolution[0] % patch_size[0] == 0 and input_resolution[1] % patch_size[1] == 0, f"input_resolution {input_resolution} should be divisible by patch_size {patch_size}" + self.input_resolution = input_resolution + self.patch_size = patch_size + self.downsampling_rate = patch_size[0] + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.num_patches_h = int(input_resolution[0] // patch_size[0]) + self.num_patches_w = int(input_resolution[1] // patch_size[1]) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches_h * self.num_patches_w + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + self.ln_post = LayerNorm(width) + + self.features_only = features_only # if True, return the final patches instead of the CLS token + if features_only: + self.channels = width + else: + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + self.channels = output_dim + + self.reduction = patch_size[0] + self.clip_embed_dim = output_dim + + def adjust_pos_embed(self, h: int, w: int) -> None: + """ + Permanently adjust the size of the positional embedding matrix. + + Args: + h: the height of the original input image. + w: the width of the original input image. + """ + assert h % self.patch_size[0] == 0 and w % self.patch_size[1] == 0, f"input_resolution {h, w} should be divisible by patch_size {self.patch_size}" + if self.input_resolution[0] != h or self.input_resolution[1] != w: + new_num_patches_h = int(h // self.patch_size[0]) + new_num_patches_w = int(w // self.patch_size[1]) + positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension + positional_embedding = F.interpolate(positional_embedding, size=(new_num_patches_h, new_num_patches_w), mode="bicubic", ).squeeze(0) # remove batch dimension + positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c") + self.positional_embedding = nn.Parameter(torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0)) + self.input_resolution = (h, w) + self.num_patches_h = new_num_patches_h + self.num_patches_w = new_num_patches_w + + def _interpolate_pos_embed(self, h: int, w: int) -> Tensor: + """ + Interpolate the positional embedding matrix to match the size of the input image. + + Args: + h: the required number of patches along the height dimension. + w: the required number of patches along the width dimension. 
+ """ + if h == self.num_patches_h and w == self.num_patches_w: + return self.positional_embedding + else: + positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension + positional_embedding = F.interpolate(positional_embedding, size=(h, w), mode="bicubic").squeeze(0) # remove batch dimension + positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c") + positional_embedding = torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0) + return positional_embedding + + def forward(self, x: Tensor) -> Tensor: + x = self.conv1(x) # shape = [*, width, grid, grid] + num_patches_h, num_patches_w = x.shape[-2:] + + positional_embedding = self._interpolate_pos_embed(num_patches_h, num_patches_w).to(x.dtype) + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([ + self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x + ], dim=1) + x = x + positional_embedding + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND. N: batch size, L: sequence length, D: feature dimension + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_post(x) + + if self.features_only: + x = x[:, 1:, :] # remove the CLS token + x = rearrange(x, "n (h w) c -> n c h w", h=num_patches_h, w=num_patches_w) + else: + x = x[:, 0, :] + x = x @ self.proj + return x diff --git a/models/clip/_clip/model.py b/models/clip/_clip/model.py new file mode 100755 index 0000000000000000000000000000000000000000..80f3939d4fc42d08863682c83eb2cc1fa0e308c5 --- /dev/null +++ b/models/clip/_clip/model.py @@ -0,0 +1,214 @@ +import torch +from torch import nn +import numpy as np + +from typing import Tuple, Union + +from .image_encoder import ModifiedResNet, VisionTransformer +from .text_encoder import LayerNorm, Transformer + + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.image_resolution = image_resolution + self.vision_layers = vision_layers + self.vision_width = vision_width + self.vision_patch_size = vision_patch_size + self.context_length = context_length + self.vocab_size = vocab_size + self.transformer_width = transformer_width + self.transformer_heads = transformer_heads + self.transformer_layers = transformer_layers + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width, + features_only=False, + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + features_only=False, + ) + self.vision_heads = vision_heads + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.token_embedding = 
nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias 
is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict, strict=False) + return model.eval() diff --git a/models/clip/_clip/prepare.py b/models/clip/_clip/prepare.py new file mode 100755 index 0000000000000000000000000000000000000000..182cae6103024a5e9142fc8d419c858240bfd3ff --- /dev/null +++ b/models/clip/_clip/prepare.py @@ -0,0 +1,95 @@ +# Prepare the models to speed up loading them later +import torch +from torch import nn, Tensor +import os +from tqdm import tqdm +import json + +from .utils import load + + +model_name_map = { + "RN50": "resnet50", + "RN101": "resnet101", + "RN50x4": "resnet50x4", + "RN50x16": "resnet50x16", + "RN50x64": "resnet50x64", + "ViT-B/32": "vit_b_32", + "ViT-B/16": "vit_b_16", + "ViT-L/14": "vit_l_14", + "ViT-L/14@336px": "vit_l_14_336px", +} + + +class CLIPTextEncoderTemp(nn.Module): + def __init__( + self, + clip: nn.Module, + ) -> None: + super().__init__() + self.context_length = clip.context_length + self.vocab_size = clip.vocab_size + self.dtype = clip.dtype + self.token_embedding = clip.token_embedding + self.positional_embedding = clip.positional_embedding + self.transformer = clip.transformer + self.ln_final = clip.ln_final + self.text_projection = clip.text_projection + + def forward(self, text: Tensor) -> None: + pass + + +def prepare() -> None: + print("Preparing CLIP models...") + curr_dir = os.path.dirname(os.path.abspath(__file__)) + 
weight_dir = os.path.join(curr_dir, "weights") + config_dir = os.path.join(curr_dir, "configs") + os.makedirs(weight_dir, exist_ok=True) + os.makedirs(config_dir, exist_ok=True) + device = torch.device("cpu") + + for model_name in tqdm(["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"]): + model = load(model_name, device=device).to(device) + image_encoder = model.visual.to(device) + text_encoder = CLIPTextEncoderTemp(model).to(device) + torch.save(model.state_dict(), os.path.join(weight_dir, f"clip_{model_name_map[model_name]}.pth")) + torch.save(image_encoder.state_dict(), os.path.join(weight_dir, f"clip_image_encoder_{model_name_map[model_name]}.pth")) + torch.save(text_encoder.state_dict(), os.path.join(weight_dir, f"clip_text_encoder_{model_name_map[model_name]}.pth")) + model_config = { + "embed_dim": model.embed_dim, + # vision + "image_resolution": model.image_resolution, + "vision_layers": model.vision_layers, + "vision_width": model.vision_width, + "vision_patch_size": model.vision_patch_size, + # text + "context_length": model.context_length, + "vocab_size": model.vocab_size, + "transformer_width": model.transformer_width, + "transformer_heads": model.transformer_heads, + "transformer_layers": model.transformer_layers, + } + image_encoder_config = { + "embed_dim": model.embed_dim, + "image_resolution": model.image_resolution, + "vision_layers": model.vision_layers, + "vision_width": model.vision_width, + "vision_patch_size": model.vision_patch_size, + "vision_heads": model.vision_heads, + } + text_encoder_config = { + "embed_dim": model.embed_dim, + "context_length": model.context_length, + "vocab_size": model.vocab_size, + "transformer_width": model.transformer_width, + "transformer_heads": model.transformer_heads, + "transformer_layers": model.transformer_layers, + } + with open(os.path.join(config_dir, f"clip_{model_name_map[model_name]}.json"), "w") as f: + json.dump(model_config, f, indent=4) + with open(os.path.join(config_dir, f"clip_image_encoder_{model_name_map[model_name]}.json"), "w") as f: + json.dump(image_encoder_config, f, indent=4) + with open(os.path.join(config_dir, f"clip_text_encoder_{model_name_map[model_name]}.json"), "w") as f: + json.dump(text_encoder_config, f, indent=4) + print("Done!") diff --git a/models/clip/_clip/simple_tokenizer.py b/models/clip/_clip/simple_tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..341fb619d6f3b0cd648fd8d938fefe8bff780bdc --- /dev/null +++ b/models/clip/_clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
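+ For example, printable bytes are mapped to the identical character, while whitespace and control bytes are remapped to unused code points above 255 so that every byte has a printable stand-in.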
+ """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/models/clip/_clip/text_encoder.py b/models/clip/_clip/text_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..64b66ee217a0f538c5dfa712b5e518290d26f706 --- /dev/null +++ b/models/clip/_clip/text_encoder.py @@ -0,0 +1,53 @@ +import torch +from torch import nn, Tensor + +from .blocks import LayerNorm, Transformer + + +class CLIPTextEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + 
context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + ) -> None: + super().__init__() + self.context_length = context_length + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.transformer.resblocks[0].attn.in_proj_weight.dtype + + def forward(self, text: Tensor): + x = self.token_embedding(text).type(self.dtype) + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x diff --git a/models/clip/_clip/utils.py b/models/clip/_clip/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..b1e1f53897425bb55f274dc62247d455b88ff2ae --- /dev/null +++ b/models/clip/_clip/utils.py @@ -0,0 +1,249 @@ +import hashlib +import os +import urllib +import warnings +from typing import Union, List +from pkg_resources import packaging + +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import torch + +from typing import List, Union +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + + + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": 
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. + + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
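+
+ Example (illustrative): tokenize(["a photo of a cat", "a dog"]) returns a tensor of shape (2, 77).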
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/models/clip/model.py b/models/clip/model.py new file mode 100755 index 0000000000000000000000000000000000000000..51db2eef829a02898b955443bb9ac667e53f719e --- /dev/null +++ b/models/clip/model.py @@ -0,0 +1,277 @@ +import torch +from torch import nn, Tensor +import torch.nn.functional as F +import numpy as np +import os +import math +from typing import List, Tuple, Union, Optional + +from . import _clip +from ..utils import _init_weights, make_resnet_layers, Bottleneck, BasicBlock +from .utils import format_count + +curr_dir = os.path.abspath(os.path.dirname(__file__)) + + +# resnet50: reduction, channels, embed_dim = 32, 2048, 1024 +# resnet101: reduction, channels, embed_dim = 32, 2048, 512 +# resnet50x4: reduction, channels, embed_dim = 32, 2560, 640 +# resnet50x16: reduction, channels, embed_dim = 32, 3072, 768 +# resnet50x64: reduction, channels, embed_dim = 32, 4096, 1024 +# vit_b_32: reduction, channels, embed_dim = 32, 768, 512 +# vit_b_16: reduction, channels, embed_dim = 16, 768, 512 +# vit_l_14: reduction, channels, embed_dim = 14, 1024, 768 +# vit_l_14_336px: reduction, channels, embed_dim = 14, 1024, 768 + +resnet_backbones = ["resnet50", "resnet101", "resnet50x4", "resnet50x16", "resnet50x64"] +vit_backbones = ["vit_b_16", "vit_b_32", "vit_l_14", "vit_l_14_336px"] + + +class CLIP_EBC(nn.Module): + def __init__( + self, + backbone: str, + bins: List[Tuple[float, float]], + anchor_points: List[float], + reduction: Optional[int] = None, + freeze_text_encoder: bool = True, + prompt_type: str = "number", + input_size: Optional[int] = None, + num_vpt: Optional[int] = None, + deep_vpt: Optional[bool] = None, + vpt_drop: Optional[float] = None, + decoder_block: Optional[nn.Module] = None, + decoder_cfg: Optional[List[Union[str, int]]] = None, + ) -> None: + super().__init__() + assert backbone in resnet_backbones + vit_backbones, f"Backbone should be in {resnet_backbones + vit_backbones}, got {backbone}" + self.backbone = backbone + + # Image encoder + if backbone in resnet_backbones: + self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, out_indices=(-1,), reduction=reduction) + + else: + assert input_size is not None, "Expected input_size to be an integer, got None." + assert num_vpt is not None, "Expected num_vpt to be an integer, got None." + assert deep_vpt is not None, "Expected deep_vpt to be a boolean, got None." + assert vpt_drop is not None, "Expected vpt_drop to be a float, got None." + + self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, input_size=input_size) + self.image_encoder_depth = len(self.image_encoder.transformer.resblocks) + + # Use VPT. Freeze the image encoder. 
+ for param in self.image_encoder.parameters(): + param.requires_grad = False + + self.num_vpt = num_vpt + self.deep_vpt = deep_vpt + + patch_size = self.image_encoder.patch_size[0] + val = math.sqrt(6. / float(3 * patch_size + self.image_encoder.channels)) + + for idx in range(self.image_encoder_depth if self.deep_vpt else 1): + setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.image_encoder.channels))) + nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val) + setattr(self, f"vpt_drop_{idx}", nn.Dropout(vpt_drop) if vpt_drop > 0 else nn.Identity()) + + self.encoder_reduction = self.image_encoder.reduction + self.reduction = self.encoder_reduction if reduction is None else reduction + self.channels = self.image_encoder.channels + self.clip_embed_dim = self.image_encoder.clip_embed_dim + + if decoder_cfg is not None: + assert decoder_block is not None, "Expected decoder_block to be a nn.Module, got None." + self.image_decoder = make_resnet_layers(decoder_block, decoder_cfg, in_channels=self.channels, expansion=1, dilation=1) + self.image_decoder.apply(_init_weights) + self.channels = decoder_cfg[-1] + else: + self.image_decoder = nn.Identity() + + if self.channels != self.clip_embed_dim: + self.projection = nn.Conv2d(in_channels=self.channels, out_channels=self.clip_embed_dim, kernel_size=1) + self.projection.apply(_init_weights) + else: + self.projection = nn.Identity() + + # Text encoder + assert prompt_type in ["number", "word"], f"Expected prompt_type to be 'number' or 'word', got {prompt_type}" + self.prompt_type = prompt_type + self.text_encoder = getattr(_clip, f"{backbone}_txt")() + self.freeze_text_encoder = freeze_text_encoder + if self.freeze_text_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + self.bins = bins + self.anchor_points = torch.tensor(anchor_points, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1) + + self._get_text_prompts() + self._tokenize_text_prompts() + + if self.freeze_text_encoder: + self._extract_text_features() + else: + self.text_features = None + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True) + + def _get_text_prompts(self) -> None: + bins = [b[0] if b[0] == b[1] else b for b in self.bins] + self.text_prompts = [format_count(b, self.prompt_type) for b in bins] + print(f"Initialized model with text prompts: {self.text_prompts}") + + def _tokenize_text_prompts(self) -> None: + self.text_prompts = _clip.tokenize(self.text_prompts) + + def _extract_text_features(self) -> None: + with torch.no_grad(): + self.text_features = self.text_encoder(self.text_prompts) + + def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor: + if not self.deep_vpt: + assert layer == 0, f"Expected layer to be 0 when using Shallow Visual Prompt Tuning, got {layer}" + + vpt = getattr(self, f"vpt_{layer}").to(device) + vpt = vpt.unsqueeze(0).expand(batch_size, -1, -1) + vpt = getattr(self, f"vpt_drop_{layer}")(vpt) + vpt = vpt.permute(1, 0, 2) # (num_vpt, batch_size, hidden_dim) + assert vpt.shape[1] == batch_size, f"Expected the VPT to have the shape [L_vis B C], got {vpt.shape}." 
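+ # vpt: (num_vpt, batch_size, channels); _forward_vpt splices these tokens in between the CLS token and the patch tokens before each transformer block.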
+ return vpt + + def _forward_vpt(self, x: Tensor) -> Tuple[Tensor]: + device = x.device + batch_size, _, height, width = x.shape + num_h_patches, num_w_patches = height // self.image_encoder.patch_size[0], width // self.image_encoder.patch_size[1] + + image_features = self.image_encoder.conv1(x) + image_features = image_features.reshape(batch_size, image_features.shape[1], -1) + image_features = image_features.permute(0, 2, 1) # (B, num_patches, C) + image_features = torch.cat([ + self.image_encoder.class_embedding + torch.zeros(batch_size, 1, image_features.shape[-1], dtype=image_features.dtype, device=device), + image_features, + ], dim=1) # (B, num_patches + 1, C) + + pos_embedding = self.image_encoder._interpolate_pos_embed(num_h_patches, num_w_patches) + image_features = image_features + pos_embedding + image_features = self.image_encoder.ln_pre(image_features) + image_features = image_features.permute(1, 0, 2) # (num_patches + 1, B, C) + assert image_features.shape[0] == num_h_patches * num_w_patches + 1 and image_features.shape[1] == batch_size, f"Expected image_features to have shape [num_patches + 1, B, C], got {image_features.shape}." + + vpt = self._prepare_vpt(0, batch_size, device) + for idx in range(self.image_encoder_depth): + # assemble + image_features = torch.cat([ + image_features[:1, :, :], # CLS token + vpt, + image_features[1:, :, :], + ], dim=0) + + # transformer + image_features = self.image_encoder.transformer.resblocks[idx](image_features) + + # disassemble + if idx < self.image_encoder_depth - 1: + if self.deep_vpt: + vpt = self._prepare_vpt(idx + 1, batch_size, device) + else: + vpt = image_features[1: (self.num_vpt + 1), :, :] + + image_features = torch.cat([ + image_features[:1, :, :], # CLS token + image_features[(self.num_vpt + 1):, :, :], + ], dim=0) + + image_features = image_features.permute(1, 0, 2) # (B, num_patches + 1, C) + image_features = self.image_encoder.ln_post(image_features) + image_features = image_features[:, 1:, :].permute(0, 2, 1) # (B, C, num_patches) + image_features = image_features.reshape(batch_size, -1, num_h_patches, num_w_patches) + return image_features + + def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: + device = x.device + + x = self.image_encoder(x) if self.backbone in resnet_backbones else self._forward_vpt(x) + if self.reduction != self.encoder_reduction: + # print("Before interpolation:", x.shape) + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + # print("After interpolation:", x.shape) + # for name, param in self.image_decoder.named_parameters(): + # print(f"Decoder parameter {name}:") + # print(f"- Shape: {param.shape}") + # print(f"- Device: {param.device}") + # print(f"- Requires grad: {param.requires_grad}") + x = self.image_decoder(x) + x = self.projection(x) + + image_features = x.permute(0, 2, 3, 1) # shape (B, H, W, C) + text_features = self.text_encoder(self.text_prompts.to(device)) if self.text_features is None else self.text_features.to(device) # shape (N, C) + + image_features = F.normalize(image_features, p=2, dim=-1) + text_features = F.normalize(text_features, p=2, dim=-1) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits = logit_scale * image_features @ text_features.t() # (B, H, W, N), logits per image + logits = logits.permute(0, 3, 1, 2) # (B, N, H, W) + + probs = logits.softmax(dim=1) + exp = (probs * self.anchor_points.to(x.device)).sum(dim=1, keepdim=True) # (B, 1, H, W) + + if self.training: + return 
logits, exp + else: + return exp + + +def _clip_ebc( + backbone: str, + bins: List[Tuple[float, float]], + anchor_points: List[float], + reduction: Optional[int] = None, + freeze_text_encoder: bool = True, + prompt_type: str = "number", + input_size: Optional[int] = None, + num_vpt: Optional[int] = None, + deep_vpt: Optional[bool] = None, + vpt_drop: Optional[float] = None, + decoder_block: Optional[nn.Module] = None, + decoder_cfg: Optional[List[Union[str, int]]] = None +) -> CLIP_EBC: + if backbone in resnet_backbones: + decoder_block = Bottleneck + if decoder_cfg is None: + if backbone == "resnet50": + decoder_cfg = [2048] + elif backbone == "resnet50x4": + decoder_cfg = [1280] + elif backbone == "resnet50x16": + decoder_cfg = [1536] + elif backbone == "resnet50x64": + decoder_cfg = [2048] + else: # backbone == "resnet101" + decoder_cfg = [2048, 1024] + else: + decoder_block = BasicBlock + if decoder_cfg is None: + if backbone == "vit_b_16": + decoder_cfg = [768] + elif backbone == "vit_b_32": + decoder_cfg = [768] + else: # backbone == "vit_l_14" + decoder_cfg = [1024] + + return CLIP_EBC( + backbone=backbone, + bins=bins, + anchor_points=anchor_points, + reduction=reduction, + freeze_text_encoder=freeze_text_encoder, + prompt_type=prompt_type, + input_size=input_size, + num_vpt=num_vpt, + deep_vpt=deep_vpt, + vpt_drop=vpt_drop, + decoder_block=decoder_block, + decoder_cfg=decoder_cfg, + ) diff --git a/models/clip/utils.py b/models/clip/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..f434ecb0fce4cf800f103472235849f0e134229c --- /dev/null +++ b/models/clip/utils.py @@ -0,0 +1,40 @@ +from typing import Union, Tuple + + +num_to_word = { + "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine", + "10": "ten", "11": "eleven", "12": "twelve", "13": "thirteen", "14": "fourteen", "15": "fifteen", "16": "sixteen", "17": "seventeen", "18": "eighteen", "19": "nineteen", + "20": "twenty", "21": "twenty-one", "22": "twenty-two", "23": "twenty-three", "24": "twenty-four", "25": "twenty-five", "26": "twenty-six", "27": "twenty-seven", "28": "twenty-eight", "29": "twenty-nine", + "30": "thirty", "31": "thirty-one", "32": "thirty-two", "33": "thirty-three", "34": "thirty-four", "35": "thirty-five", "36": "thirty-six", "37": "thirty-seven", "38": "thirty-eight", "39": "thirty-nine", + "40": "forty", "41": "forty-one", "42": "forty-two", "43": "forty-three", "44": "forty-four", "45": "forty-five", "46": "forty-six", "47": "forty-seven", "48": "forty-eight", "49": "forty-nine", + "50": "fifty", "51": "fifty-one", "52": "fifty-two", "53": "fifty-three", "54": "fifty-four", "55": "fifty-five", "56": "fifty-six", "57": "fifty-seven", "58": "fifty-eight", "59": "fifty-nine", + "60": "sixty", "61": "sixty-one", "62": "sixty-two", "63": "sixty-three", "64": "sixty-four", "65": "sixty-five", "66": "sixty-six", "67": "sixty-seven", "68": "sixty-eight", "69": "sixty-nine", + "70": "seventy", "71": "seventy-one", "72": "seventy-two", "73": "seventy-three", "74": "seventy-four", "75": "seventy-five", "76": "seventy-six", "77": "seventy-seven", "78": "seventy-eight", "79": "seventy-nine", + "80": "eighty", "81": "eighty-one", "82": "eighty-two", "83": "eighty-three", "84": "eighty-four", "85": "eighty-five", "86": "eighty-six", "87": "eighty-seven", "88": "eighty-eight", "89": "eighty-nine", + "90": "ninety", "91": "ninety-one", "92": "ninety-two", "93": "ninety-three", "94": "ninety-four", "95": "ninety-five", 
"96": "ninety-six", "97": "ninety-seven", "98": "ninety-eight", "99": "ninety-nine", + "100": "one hundred", "200": "two hundred", "300": "three hundred", "400": "four hundred", "500": "five hundred", "600": "six hundred", "700": "seven hundred", "800": "eight hundred", "900": "nine hundred", + "1000": "one thousand" +} + + +def num2word(num: Union[int, str]) -> str: + """ + Convert the input number to the corresponding English word. For example, 1 -> "one", 2 -> "two", etc. + """ + num = str(int(num)) + return num_to_word.get(num, num) + + +def format_count(count: Union[float, Tuple[float, float]], prompt_type: str = "word") -> str: + if count == 0: + return "There is no person." if prompt_type == "word" else "There is 0 person." + elif count == 1: + return "There is one person." if prompt_type == "word" else "There is 1 person." + elif isinstance(count, (int, float)): + return f"There are {num2word(int(count))} people." if prompt_type == "word" else f"There are {int(count)} people." + elif count[1] == float("inf"): + return f"There are more than {num2word(int(count[0]))} people." if prompt_type == "word" else f"There are more than {int(count[0])} people." + else: # count is a tuple of finite numbers + left, right = int(count[0]), int(count[1]) + left, right = num2word(left), num2word(right) if prompt_type == "word" else left, right + return f"There are between {left} and {right} people." diff --git a/models/encoder/__init__.py b/models/encoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..5e3645634764165ff0e7a0faab5633c1bb6f2d35 --- /dev/null +++ b/models/encoder/__init__.py @@ -0,0 +1,10 @@ +from .vgg import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn +from .vit import vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14 +from .timm_models import _timm_encoder + + +__all__ = [ + "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn", + "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", "vit_h_14", + "_timm_encoder", +] diff --git a/models/encoder/timm_models.py b/models/encoder/timm_models.py new file mode 100755 index 0000000000000000000000000000000000000000..803662b924d595b40088b57ee8b90483398a8173 --- /dev/null +++ b/models/encoder/timm_models.py @@ -0,0 +1,54 @@ +from timm import create_model, list_models +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional + +from warnings import warn + + +class TIMMEncoder(nn.Module): + def __init__( + self, + backbone: str, + reduction: Optional[int] = None, + ) -> None: + super().__init__() + assert backbone in list_models(), f"Backbone {backbone} not available in timm" + encoder = create_model(backbone, pretrained=True, features_only=True, out_indices=[-1]) + encoder_reduction = encoder.feature_info.reduction()[-1] + + if reduction <= 16: + if "resnet" in backbone: + if "resnet18" in backbone or "resnet34" in backbone: + encoder.layer4[0].conv1.stride = (1, 1) + encoder.layer4[0].downsample[0].stride = (1, 1) + else: + encoder.layer4[0].conv2.stride = (1, 1) + encoder.layer4[0].downsample[0].stride = (1, 1) + encoder_reduction = encoder_reduction // 2 + + elif "mobilenetv2" in backbone: + encoder.blocks[5][0].conv_dw.stride = (1, 1) + encoder_reduction = encoder_reduction // 2 + + elif "densenet" in backbone: + encoder.features_transition3.pool = nn.Identity() + encoder_reduction = encoder_reduction // 2 + + else: + warn(f"Reduction for {backbone} not handled. 
Using default reduction of {encoder_reduction}") + + self.encoder = encoder + self.encoder_reduction = encoder_reduction + self.reduction = self.encoder_reduction if reduction is None else reduction + self.channels = self.encoder.feature_info.channels()[-1] + + def forward(self, x: Tensor) -> Tensor: + x = self.encoder(x)[-1] + if self.encoder_reduction != self.reduction: + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + return x + + +def _timm_encoder(backbone: str, reduction: Optional[int] = None) -> TIMMEncoder: + return TIMMEncoder(backbone, reduction) diff --git a/models/encoder/vgg.py b/models/encoder/vgg.py new file mode 100755 index 0000000000000000000000000000000000000000..e60c45b22add175f0b271685153ee622240c5c4b --- /dev/null +++ b/models/encoder/vgg.py @@ -0,0 +1,69 @@ +from torch import nn, Tensor +import torch.nn.functional as F +from torch.hub import load_state_dict_from_url +from typing import Optional + +from ..utils import make_vgg_layers, vgg_cfgs, vgg_urls + + +class VGG(nn.Module): + def __init__( + self, + features: nn.Module, + reduction: Optional[int] = None, + ) -> None: + super().__init__() + self.features = features + self.encoder_reduction = 16 + self.reduction = self.encoder_reduction if reduction is None else reduction + self.channels = 512 + + def forward(self, x: Tensor) -> Tensor: + x = self.features(x) + if self.encoder_reduction != self.reduction: + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + return x + + +def _load_weights(model: VGG, url: str) -> VGG: + state_dict = load_state_dict_from_url(url) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Loading pre-trained weights") + if len(missing_keys) > 0: + print(f"Missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + print(f"Unexpected keys: {unexpected_keys}") + return model + + +def vgg11(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["A"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg11"]) + +def vgg11_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["A"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg11_bn"]) + +def vgg13(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["B"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg13"]) + +def vgg13_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["B"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg13_bn"]) + +def vgg16(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["D"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg16"]) + +def vgg16_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["D"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg16_bn"]) + +def vgg19(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["E"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg19"]) + +def vgg19_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["E"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg19_bn"]) diff --git a/models/encoder/vit.py b/models/encoder/vit.py new file mode 100755 index 0000000000000000000000000000000000000000..92d7cf895f83b23d36a47256a87af0c925ebb561 --- /dev/null +++ b/models/encoder/vit.py @@ -0,0 +1,526 @@ +import math +from collections import 
OrderedDict +from functools import partial +from typing import Any, Callable, List, NamedTuple, Optional, Tuple + +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from torch.hub import load_state_dict_from_url +from einops import rearrange + +from ..utils import Conv2dNormActivation, MLP +from ..utils import _log_api_usage_once + + +weights = { + "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", + "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", + "vit_h_14": "https://download.pytorch.org/models/vit_h_14-6kbcf7eb.pth", +} + + +class ConvStemConfig(NamedTuple): + out_channels: int + kernel_size: int + stride: int + norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d + activation_layer: Callable[..., nn.Module] = nn.ReLU + + +class MLPBlock(MLP): + """Transformer MLP block.""" + + _version = 2 + + def __init__(self, in_dim: int, mlp_dim: int, dropout: float): + super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053 + for i in range(2): + for type in ["weight", "bias"]: + old_key = f"{prefix}linear_{i+1}.{type}" + new_key = f"{prefix}{3*i}.{type}" + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class EncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + self.num_heads = num_heads + + # Attention block + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) + self.dropout = nn.Dropout(dropout) + + # MLP block + self.ln_2 = norm_layer(hidden_dim) + self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) + + def forward(self, input: Tensor): + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + x = self.ln_1(input) + x, _ = self.self_attention(x, x, x, need_weights=False) + x = self.dropout(x) + x = x + input + + y = self.ln_2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + def __init__( + self, + num_h_patches: int, + num_w_patches: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + self.num_h_patches = num_h_patches + self.num_w_patches = num_w_patches + + # Note that batch_size is on the first dim because + # we have batch_first=True in 
nn.MultiAttention() by default + seq_length = num_h_patches * num_w_patches + 1 # +1 for the class token + self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT + self.dropout = nn.Dropout(dropout) + layers: OrderedDict[str, nn.Module] = OrderedDict() + for i in range(num_layers): + layers[f"encoder_layer_{i}"] = EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.layers = nn.Sequential(layers) + self.ln = norm_layer(hidden_dim) + + def _get_pos_embedding(self, n_h: int, n_w: int) -> Tensor: + if n_h == self.num_h_patches and n_w == self.num_w_patches: + return self.pos_embedding + else: + pos_embedding = self.pos_embedding[:, 1:, :] + pos_embedding = rearrange(pos_embedding, "1 (h w) d -> 1 d h w", h=self.num_h_patches, w=self.num_w_patches) + pos_embedding = F.interpolate(pos_embedding, size=(n_h, n_w), mode="bicubic") + pos_embedding = rearrange(pos_embedding, "1 d h w -> 1 (h w) d") + return torch.cat([self.pos_embedding[:, :1, :], pos_embedding], dim=1) + + def forward(self, input: Tensor, n_h: int, n_w: int) -> Tensor: + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + input = input + self._get_pos_embedding(n_h, n_w) + return self.ln(self.layers(self.dropout(input))) + + +class VisionTransformer(nn.Module): + """Vision Transformer as a feature extractor.""" + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0.0, + attention_dropout: float = 0.0, + # num_classes: int = 1000, # No need for the classification head as we only need the features + reduction: Optional[int] = None, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): + super().__init__() + _log_api_usage_once(self) + torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.attention_dropout = attention_dropout + self.dropout = dropout + # self.num_classes = num_classes + self.representation_size = representation_size + self.norm_layer = norm_layer + + if conv_stem_configs is not None: + # As per https://arxiv.org/abs/2106.14881 + seq_proj = nn.Sequential() + prev_channels = 3 + for i, conv_stem_layer_config in enumerate(conv_stem_configs): + seq_proj.add_module( + f"conv_bn_relu_{i}", + Conv2dNormActivation( + in_channels=prev_channels, + out_channels=conv_stem_layer_config.out_channels, + kernel_size=conv_stem_layer_config.kernel_size, + stride=conv_stem_layer_config.stride, + norm_layer=conv_stem_layer_config.norm_layer, + activation_layer=conv_stem_layer_config.activation_layer, + ), + ) + prev_channels = conv_stem_layer_config.out_channels + seq_proj.add_module( + "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1) + ) + self.conv_proj: nn.Module = seq_proj + else: + self.conv_proj = nn.Conv2d( + in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size + ) + + seq_length = (image_size // patch_size) ** 2 + + # Add a class token + self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + seq_length += 1 + + self.encoder = Encoder( + image_size // patch_size, + image_size // patch_size, + num_layers, + 
num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.seq_length = seq_length + + # heads_layers: OrderedDict[str, nn.Module] = OrderedDict() + # if representation_size is None: + # heads_layers["head"] = nn.Linear(hidden_dim, num_classes) + # else: + # heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) + # heads_layers["act"] = nn.Tanh() + # heads_layers["head"] = nn.Linear(representation_size, num_classes) + + # self.heads = nn.Sequential(heads_layers) + + if isinstance(self.conv_proj, nn.Conv2d): + # Init the patchify stem + fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1] + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) + if self.conv_proj.bias is not None: + nn.init.zeros_(self.conv_proj.bias) + elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d): + # Init the last 1x1 conv of the conv stem + nn.init.normal_( + self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels) + ) + if self.conv_proj.conv_last.bias is not None: + nn.init.zeros_(self.conv_proj.conv_last.bias) + + # if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear): + # fan_in = self.heads.pre_logits.in_features + # nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) + # nn.init.zeros_(self.heads.pre_logits.bias) + + # if isinstance(self.heads.head, nn.Linear): + # nn.init.zeros_(self.heads.head.weight) + # nn.init.zeros_(self.heads.head.bias) + + self.encoder_reduction = self.patch_size + self.reduction = self.encoder_reduction if reduction is None else reduction + self.channels = hidden_dim + + def _process_input(self, x: Tensor) -> Tuple[Tensor, int, int, int]: + # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) + x = self.conv_proj(x) + n, _, n_h, n_w = x.shape + # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) + x = x.reshape(n, self.hidden_dim, n_h * n_w) + + # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) + # The self attention layer expects inputs in the format (N, S, E) + # where S is the source sequence length, N is the batch size, E is the + # embedding dimension + x = x.permute(0, 2, 1) + + return x, n, n_h, n_w + + def forward(self, x: Tensor) -> Tensor: + # Reshape and permute the input tensor + x, n, n_h, n_w = self._process_input(x) + + # Expand the class token to the full batch + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + + x = self.encoder(x, n_h, n_w) # Allows input image to be of any size. 
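+ # x: (n, 1 + n_h * n_w, hidden_dim) after the encoder; the class token is dropped below and the
+ # patch tokens are reshaped into an (n, hidden_dim, n_h, n_w) feature map, optionally resized to
+ # match the requested reduction.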
+ + # Classifier "token" as used by standard language architectures + # x = x[:, 0] + + # x = self.heads(x) + + x = x[:, 1:, :] + x = rearrange(x, "n (h w) d -> n d h w", h=n_h, w=n_w) + if self.encoder_reduction != self.reduction: + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + return x # To be consistent with timm models + + +def _vision_transformer( + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + weights: str, + **kwargs: Any, +) -> VisionTransformer: + image_size = kwargs.pop("image_size", 224) + + model = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + **kwargs, + ) + + if weights is not None: + weights = load_state_dict_from_url(weights, progress=kwargs.get("progress", True)) + missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False) + if len(missing_keys) > 0: + print(f"Missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + print(f"Unexpected keys: {unexpected_keys}") + + return model + + +def interpolate_embeddings( + image_size: int, + patch_size: int, + pos_embedding: Tensor, + interpolation_mode: str = "bicubic", +) -> Tensor: + """This function helps interpolate positional embeddings during checkpoint loading, + especially when you want to apply a pre-trained model on images with different resolution. + + Args: + image_size (int): Image size of the new model. + patch_size (int): Patch size of the new model. + model_state (OrderedDict[str, Tensor]): State dict of the pre-trained model. + interpolation_mode (str): The algorithm used for upsampling. Default: bicubic. + reset_heads (bool): If true, not copying the state of heads. Default: False. + + Returns: + Tensor: The interpolated positional embedding. + """ + # Shape of pos_embedding is (1, seq_length, hidden_dim) + n, seq_length, hidden_dim = pos_embedding.shape + if n != 1: + raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}") + + new_seq_length = (image_size // patch_size) ** 2 + 1 + + # Need to interpolate the weights for the position embedding. + # We do this by reshaping the positions embeddings to a 2d grid, performing + # an interpolation in the (h, w) space and then reshaping back to a 1d grid. + if new_seq_length != seq_length: + # The class token embedding shouldn't be interpolated, so we split it up. + seq_length -= 1 + new_seq_length -= 1 + pos_embedding_token = pos_embedding[:, :1, :] + pos_embedding_img = pos_embedding[:, 1:, :] + + # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length) + pos_embedding_img = pos_embedding_img.permute(0, 2, 1) + seq_length_1d = int(math.sqrt(seq_length)) + if seq_length_1d * seq_length_1d != seq_length: + raise ValueError( + f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}" + ) + + # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d) + pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d) + new_seq_length_1d = image_size // patch_size + + # Perform interpolation. 
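+ # The position embeddings are treated as a small 2-D grid so that resampling (bicubic by default) preserves their spatial arrangement.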
+ # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) + new_pos_embedding_img = nn.functional.interpolate( + pos_embedding_img, + size=new_seq_length_1d, + mode=interpolation_mode, + ) + + # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length) + new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length) + + # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim) + new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1) + new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1) + + return new_pos_embedding + + return pos_embedding + + +def vit_b_16( + image_size: int = 224, + reduction: int = 16, + **kwargs: Any, +) -> VisionTransformer: + vit = _vision_transformer( + patch_size=16, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + weights=weights["vit_b_16"], + reduction=reduction, + **kwargs, + ) + if image_size != 224: + vit.image_size = image_size + new_pos_embedding = interpolate_embeddings(image_size, 16, vit.state_dict()["encoder.pos_embedding"], "bicubic") + vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True) + return vit + + +def vit_b_32( + image_size: int = 224, + reduction: int = 32, + **kwargs: Any, +) -> VisionTransformer: + vit = _vision_transformer( + patch_size=32, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + weights=weights["vit_b_32"], + reduction=reduction, + **kwargs, + ) + if image_size != 224: + vit.image_size = image_size + new_pos_embedding = interpolate_embeddings(image_size, 32, vit.state_dict()["encoder.pos_embedding"], "bicubic") + vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True) + return vit + + +def vit_l_16( + image_size: int = 224, + reduction: int = 16, + **kwargs: Any, +) -> VisionTransformer: + vit = _vision_transformer( + patch_size=16, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + weights=weights["vit_l_16"], + reduction=reduction, + **kwargs, + ) + if image_size != 224: + vit.image_size = image_size + new_pos_embedding = interpolate_embeddings(image_size, 16, vit.state_dict()["encoder.pos_embedding"], "bicubic") + vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True) + return vit + + +def vit_l_32( + image_size: int = 224, + reduction: int = 32, + **kwargs: Any, +) -> VisionTransformer: + vit = _vision_transformer( + patch_size=32, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + weights=weights["vit_l_32"], + reduction=reduction, + **kwargs, + ) + if image_size != 224: + vit.image_size = image_size + new_pos_embedding = interpolate_embeddings(image_size, 32, vit.state_dict()["encoder.pos_embedding"], "bicubic") + vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True) + return vit + + +def vit_h_14( + image_size: int = 224, + reduction: int = 14, + **kwargs: Any, +) -> VisionTransformer: + vit = _vision_transformer( + patch_size=14, + num_layers=32, + num_heads=16, + hidden_dim=1280, + mlp_dim=5120, + weights=weights["vit_h_14"], + reduction=reduction, + **kwargs, + ) + if image_size != 224: + vit.image_size = image_size + new_pos_embedding = interpolate_embeddings(image_size, 14, vit.state_dict()["encoder.pos_embedding"], "bicubic") + vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True) + return vit + diff --git a/models/encoder_decoder/__init__.py b/models/encoder_decoder/__init__.py new file mode 
100755 index 0000000000000000000000000000000000000000..876bde097b78f23b30260c83318e056e8c8b00b7 --- /dev/null +++ b/models/encoder_decoder/__init__.py @@ -0,0 +1,17 @@ +from .vgg import vgg11 as vgg11_ae, vgg11_bn as vgg11_bn_ae +from .vgg import vgg13 as vgg13_ae, vgg13_bn as vgg13_bn_ae +from .vgg import vgg16 as vgg16_ae, vgg16_bn as vgg16_bn_ae +from .vgg import vgg19 as vgg19_ae, vgg19_bn as vgg19_bn_ae +from .resnet import resnet18 as resnet18_ae, resnet34 as resnet34_ae +from .resnet import resnet50 as resnet50_ae, resnet101 as resnet101_ae, resnet152 as resnet152_ae + +from .cannet import cannet, cannet_bn +from .csrnet import csrnet, csrnet_bn + + +__all__ = [ + "vgg11_ae", "vgg11_bn_ae", "vgg13_ae", "vgg13_bn_ae", "vgg16_ae", "vgg16_bn_ae", "vgg19_ae", "vgg19_bn_ae", + "resnet18_ae", "resnet34_ae", "resnet50_ae", "resnet101_ae", "resnet152_ae", + "cannet", "cannet_bn", + "csrnet", "csrnet_bn", +] diff --git a/models/encoder_decoder/cannet.py b/models/encoder_decoder/cannet.py new file mode 100755 index 0000000000000000000000000000000000000000..aa5fc832b3c6ac2c3159c4eb17a00050ed10d34e --- /dev/null +++ b/models/encoder_decoder/cannet.py @@ -0,0 +1,85 @@ +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from typing import List, Optional + +from ..utils import _init_weights +from .csrnet import CSRNet, csrnet, csrnet_bn + +EPS = 1e-6 + + +class ContextualModule(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = 512, + sizes: List[int] = [1, 2, 3, 6], + ) -> None: + super().__init__() + self.scales = nn.ModuleList([self.__make_scale__(in_channels, size) for size in sizes]) + self.bottleneck = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1) + self.relu = nn.ReLU(inplace=True) + self.weight_net = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def __make_weight__(self, feature: Tensor, scale_feature: Tensor) -> Tensor: + weight_feature = feature - scale_feature + weight_feature = self.weight_net(weight_feature) + return F.sigmoid(weight_feature) + + def __make_scale__(self, channels: int, size: int) -> nn.Module: + return nn.Sequential( + nn.AdaptiveAvgPool2d(output_size=(size, size)), + nn.Conv2d(channels, channels, kernel_size=1, bias=False), + ) + + def forward(self, feature: Tensor) -> Tensor: + h, w = feature.shape[-2:] + multi_scales = [F.interpolate(input=scale(feature), size=(h, w), mode="bilinear") for scale in self.scales] + weights = [self.__make_weight__(feature, scale_feature) for scale_feature in multi_scales] + multi_scales = sum([multi_scales[i] * weights[i] for i in range(len(weights))]) / (sum(weights) + EPS) + overall_features = torch.cat([multi_scales, feature], dim=1) + overall_features = self.bottleneck(overall_features) + overall_features = self.relu(overall_features) + return overall_features + + +class CANNet(nn.Module): + def __init__( + self, + csrnet: CSRNet, + sizes: List[int] = [1, 2, 3, 6], + reduction: Optional[int] = 8, + ) -> None: + super().__init__() + assert isinstance(csrnet, CSRNet), f"csrnet should be an instance of CSRNet, got {type(csrnet)}." + assert isinstance(sizes, (tuple, list)), f"sizes should be a list or tuple, got {type(sizes)}." + assert len(sizes) > 0, f"Expected at least one size, got {len(sizes)}." + assert all([isinstance(size, int) for size in sizes]), f"Expected all size to be int, got {sizes}." 
+ self.sizes = sizes + self.encoder_reduction = csrnet.encoder_reduction + self.reduction = self.encoder_reduction if reduction is None else reduction + + self.features = csrnet.features + self.decoder = csrnet.decoder + self.decoder.apply(_init_weights) + self.context = ContextualModule(512, 512, self.sizes) + self.context.apply(_init_weights) + + self.channels = csrnet.channels + + def forward(self, x: Tensor) -> Tensor: + x = self.features(x) + x = self.context(x) + if self.encoder_reduction != self.reduction: + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + x = self.decoder(x) + return x + + +def cannet(sizes=[1, 2, 3, 6], reduction: int = 8) -> CANNet: + return CANNet(csrnet(), sizes=sizes, reduction=reduction) + +def cannet_bn(sizes=[1, 2, 3, 6], reduction: int = 8) -> CANNet: + return CANNet(csrnet_bn(), sizes=sizes, reduction=reduction) diff --git a/models/encoder_decoder/csrnet.py b/models/encoder_decoder/csrnet.py new file mode 100755 index 0000000000000000000000000000000000000000..03479409019cbb5b7ba1b4cfc9f6a8f8e3140acb --- /dev/null +++ b/models/encoder_decoder/csrnet.py @@ -0,0 +1,54 @@ +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional + +from ..utils import _init_weights, make_vgg_layers, vgg_urls +from .vgg import _load_weights + +EPS = 1e-6 + + +encoder_cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512] +decoder_cfg = [512, 512, 512, 256, 128, 64] + + +class CSRNet(nn.Module): + def __init__( + self, + features: nn.Module, + decoder: nn.Module, + reduction: Optional[int] = None, + ) -> None: + super().__init__() + self.features = features + self.features.apply(_init_weights) + self.decoder = decoder + self.decoder.apply(_init_weights) + + self.encoder_reduction = 8 + self.reduction = self.encoder_reduction if reduction is None else reduction + self.channels = 64 + + def forward(self, x: Tensor) -> Tensor: + x = self.features(x) + if self.encoder_reduction != self.reduction: + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + x = self.decoder(x) + return x + + +def csrnet(reduction: int = 8) -> CSRNet: + model = CSRNet( + make_vgg_layers(encoder_cfg, in_channels=3, batch_norm=False, dilation=1), + make_vgg_layers(decoder_cfg, in_channels=512, batch_norm=False, dilation=2), + reduction=reduction + ) + return _load_weights(model, vgg_urls["vgg16"]) + +def csrnet_bn(reduction: int = 8) -> CSRNet: + model = CSRNet( + make_vgg_layers(encoder_cfg, in_channels=3, batch_norm=True, dilation=1), + make_vgg_layers(decoder_cfg, in_channels=512, batch_norm=True, dilation=2), + reduction=reduction + ) + return _load_weights(model, vgg_urls["vgg16"]) diff --git a/models/encoder_decoder/resnet.py b/models/encoder_decoder/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..c8917aa060b5d8936dfe736d5bfa48f831f970b1 --- /dev/null +++ b/models/encoder_decoder/resnet.py @@ -0,0 +1,95 @@ +from torch import nn, Tensor +import torch.nn.functional as F +import timm +from typing import Union, Optional + +from ..utils import BasicBlock, Bottleneck, make_resnet_layers +from ..utils import _init_weights + + +model_configs = { + "resnet18.tv_in1k": { + "decoder_channels": [512, 256, 128], + }, + "resnet34.tv_in1k": { + "decoder_channels": [512, 256, 128], + }, + "resnet50.tv_in1k": { + "decoder_channels": [512, 256, 256, 128], + }, + "resnet101.tv_in1k": { + "decoder_channels": [512, 512, 256, 256, 128], + }, + 
"resnet152.tv_in1k": { + "decoder_channels": [512, 512, 512, 256, 256, 128], + }, +} + + +class ResNet(nn.Module): + def __init__( + self, + decoder_block: Union[BasicBlock, Bottleneck], + backbone: str = "resnet34.tv_in1k", + reduction: Optional[int] = None, + ) -> None: + super().__init__() + assert backbone in model_configs.keys(), f"Backbone should be in {model_configs.keys()}" + config = model_configs[backbone] + encoder = timm.create_model(backbone, pretrained=True, features_only=True, out_indices=(-1,)) + encoder_reduction = encoder.feature_info.reduction()[-1] + + if reduction <= 16: + if "resnet18" in backbone or "resnet34" in backbone: + encoder.layer4[0].conv1.stride = (1, 1) + encoder.layer4[0].downsample[0].stride = (1, 1) + else: + encoder.layer4[0].conv2.stride = (1, 1) + encoder.layer4[0].downsample[0].stride = (1, 1) + encoder_reduction = encoder_reduction // 2 + + self.encoder = encoder + self.encoder_reduction = encoder_reduction + + encoder_out_channels = self.encoder.feature_info.channels()[-1] + + decoder_channels = config["decoder_channels"] + self.decoder = make_resnet_layers( + block=decoder_block, + cfg=decoder_channels, + in_channels=encoder_out_channels, + dilation=1, + expansion=1, + ) + self.decoder.apply(_init_weights) + + self.reduction = self.encoder_reduction if reduction is None else reduction + self.channels = decoder_channels[-1] + + def forward(self, x: Tensor) -> Tensor: + x = self.encoder(x)[-1] + if self.encoder_reduction != self.reduction: + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + x = self.decoder(x) + + return x + + +def resnet18(reduction: int = 32) -> ResNet: + return ResNet(decoder_block=BasicBlock, backbone="resnet18.tv_in1k", reduction=reduction) + + +def resnet34(reduction: int = 32) -> ResNet: + return ResNet(decoder_block=BasicBlock, backbone="resnet34.tv_in1k", reduction=reduction) + + +def resnet50(reduction: int = 32) -> ResNet: + return ResNet(decoder_block=Bottleneck, backbone="resnet50.tv_in1k", reduction=reduction) + + +def resnet101(reduction: int = 32) -> ResNet: + return ResNet(decoder_block=Bottleneck, backbone="resnet101.tv_in1k", reduction=reduction) + + +def resnet152(reduction: int = 32) -> ResNet: + return ResNet(decoder_block=Bottleneck, backbone="resnet152.tv_in1k", reduction=reduction) diff --git a/models/encoder_decoder/vgg.py b/models/encoder_decoder/vgg.py new file mode 100755 index 0000000000000000000000000000000000000000..a054bda669f0c941f74abb294105aec836bf6237 --- /dev/null +++ b/models/encoder_decoder/vgg.py @@ -0,0 +1,85 @@ +# The model used in the paper Distribution Matching for Crowd Counting. +# Code adapted from https://github.com/cvlab-stonybrook/DM-Count/blob/master/models.py +from torch import nn, Tensor +import torch.nn.functional as F +from torch.hub import load_state_dict_from_url +from typing import Optional + +from ..utils import make_vgg_layers, vgg_cfgs, vgg_urls +from ..utils import _init_weights + + + +class VGG(nn.Module): + def __init__( + self, + features: nn.Module, + reduction: Optional[int] = None, + ) -> None: + super().__init__() + self.features = features + self.reg_layer = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + self.reg_layer.apply(_init_weights) + # Remove the density layer, as the output from this model is not final and will be further processed. 
+ # self.density_layer = nn.Sequential(nn.Conv2d(128, 1, 1), nn.ReLU()) + self.encoder_reduction = 16 + self.reduction = self.encoder_reduction if reduction is None else reduction + self.channels = 128 + + def forward(self, x: Tensor) -> Tensor: + x = self.features(x) + if self.encoder_reduction != self.reduction: + x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear") + x = self.reg_layer(x) + # x = self.density_layer(x) + return x + + +def _load_weights(model: VGG, url: str) -> VGG: + state_dict = load_state_dict_from_url(url) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Loading pre-trained weights") + if len(missing_keys) > 0: + print(f"Missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + print(f"Unexpected keys: {unexpected_keys}") + return model + + +def vgg11(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["A"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg11"]) + +def vgg11_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["A"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg11_bn"]) + +def vgg13(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["B"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg13"]) + +def vgg13_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["B"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg13_bn"]) + +def vgg16(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["D"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg16"]) + +def vgg16_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["D"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg16_bn"]) + +def vgg19(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["E"]), reduction=reduction) + return _load_weights(model, vgg_urls["vgg19"]) + +def vgg19_bn(reduction: int = 8) -> VGG: + model = VGG(make_vgg_layers(vgg_cfgs["E"], batch_norm=True), reduction=reduction) + return _load_weights(model, vgg_urls["vgg19_bn"]) diff --git a/models/model.py b/models/model.py new file mode 100755 index 0000000000000000000000000000000000000000..bb78083b7acce72d671c7a28497e11c38b736556 --- /dev/null +++ b/models/model.py @@ -0,0 +1,112 @@ +import torch +from torch import nn, Tensor +import os +from typing import List, Tuple, Union, Callable +from functools import partial + +from .utils import _init_weights + +from . import encoder +from . 
import encoder_decoder +from .encoder import _timm_encoder + + +curr_dir = os.path.abspath(os.path.dirname(__file__)) + + +class Regressor(nn.Module): + def __init__(self, backbone: nn.Module) -> None: + super().__init__() + self.backbone = backbone + self.reduction = backbone.reduction + + self.regressor = nn.Sequential( + nn.Conv2d(backbone.channels, 1, kernel_size=1), + nn.ReLU(inplace=True), + ) + self.regressor.apply(_init_weights) + self.bins = None + self.anchor_points = None + + def forward(self, x: Tensor) -> Tensor: + x = self.backbone(x) + x = self.regressor(x) + return x + + +class Classifier(nn.Module): + def __init__( + self, + backbone: nn.Module, + bins: List[Tuple[float, float]], + anchor_points: List[float], + ) -> None: + super().__init__() + self.backbone = backbone + self.reduction = backbone.reduction + + assert len(bins) == len(anchor_points), f"Expected bins and anchor_points to have the same length, got {len(bins)} and {len(anchor_points)}" + assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}" + assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, anchor_points)), f"Expected anchor_points to be within the range of the corresponding bin, got {bins} and {anchor_points}" + + self.bins = bins + self.anchor_points = torch.tensor(anchor_points, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1) + + if backbone.channels > 512: + self.classifier = nn.Sequential( + nn.Conv2d(backbone.channels, 512, kernel_size=1), # serves as a linear layer for feature vectors at each pixel + nn.ReLU(inplace=True), + nn.Conv2d(512, len(self.bins), kernel_size=1), + ) + else: + self.classifier = nn.Conv2d(backbone.channels, len(self.bins), kernel_size=1) + + self.classifier.apply(_init_weights) + + def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: + x = self.backbone(x) + x = self.classifier(x) # shape (B, C, H, W), where C = len(bins), x is the logits + + probs = x.softmax(dim=1) # shape (B, C, H, W) + exp = (probs * self.anchor_points.to(x.device)).sum(dim=1, keepdim=True) # shape (B, 1, H, W) + if self.training: + return x, exp + else: + return exp + + +def _get_backbone(backbone: str, input_size: int, reduction: int) -> Callable: + assert "clip" not in backbone, f"This function does not support CLIP model, got {backbone}" + + if backbone in ["vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", "vit_h_14"]: + return partial(getattr(encoder, backbone), image_size=input_size, reduction=reduction) + elif backbone in ["vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn"]: + return partial(getattr(encoder, backbone), reduction=reduction) + elif backbone in ["vgg11_ae", "vgg11_bn_ae", "vgg13_ae", "vgg13_bn_ae", "vgg16_ae", "vgg16_bn_ae", "vgg19_ae", "vgg19_bn_ae"]: + return partial(getattr(encoder_decoder, backbone), reduction=reduction) + elif backbone in ["resnet18_ae", "resnet34_ae", "resnet50_ae", "resnet101_ae", "resnet152_ae"]: + return partial(getattr(encoder_decoder, backbone), reduction=reduction) + elif backbone in ["cannet", "cannet_bn", "csrnet", "csrnet_bn"]: + return partial(getattr(encoder_decoder, backbone), reduction=reduction) + else: + return partial(_timm_encoder, backbone=backbone, reduction=reduction) + + +def _regressor( + backbone: str, + input_size: int, + reduction: int, +) -> Regressor: + backbone = _get_backbone(backbone.lower(), input_size, reduction) + return Regressor(backbone()) + + +def _classifier( + backbone: nn.Module, + input_size: int, + reduction: 
int, + bins: List[Tuple[float, float]], + anchor_points: List[float], +) -> Classifier: + backbone = _get_backbone(backbone.lower(), input_size, reduction) + return Classifier(backbone(), bins, anchor_points) diff --git a/models/utils.py b/models/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..843ff7e632200481490b2332ac1bdbf1db2289ec --- /dev/null +++ b/models/utils.py @@ -0,0 +1,452 @@ +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from functools import partial +from typing import Callable, Optional, Sequence, Tuple, Union, Any, List, TypeVar, List +from types import FunctionType +from itertools import repeat +import warnings +import os +from collections.abc import Iterable + +V = TypeVar("V") +curr_dir = os.path.dirname(os.path.abspath(__file__)) + + +vgg_urls = { + "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", + "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", +} + +vgg_cfgs = { + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512], + "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512], + "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512] +} + + +def _log_api_usage_once(obj: Any) -> None: + + """ + Logs API usage(module and name) within an organization. + In a large ecosystem, it's often useful to track the PyTorch and + TorchVision APIs usage. This API provides the similar functionality to the + logging module in the Python stdlib. It can be used for debugging purpose + to log which methods are used and by default it is inactive, unless the user + manually subscribes a logger via the `SetAPIUsageLogger method `_. + Please note it is triggered only once for the same API call within a process. + It does not collect any data from open-source users since it is no-op by default. + For more information, please refer to + * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; + * Logging policy: https://github.com/pytorch/vision/issues/5052; + + Args: + obj (class instance or method): an object to extract info from. + """ + module = obj.__module__ + if not module.startswith("torchvision"): + module = f"torchvision.internal.{module}" + name = obj.__class__.__name__ + if isinstance(obj, FunctionType): + name = obj.__name__ + torch._C._log_api_usage_once(f"{module}.{name}") + + +def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: + """ + Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. + Otherwise, we will make a tuple of length n, all with value of x. 
+ reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 + + Args: + x (Any): input value + n (int): length of the resulting tuple + """ + if isinstance(x, Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + +class ConvNormActivation(torch.nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]] = 3, + stride: Union[int, Tuple[int, ...]] = 1, + padding: Optional[Union[int, Tuple[int, ...], str]] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: Union[int, Tuple[int, ...]] = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, + ) -> None: + + if padding is None: + if isinstance(kernel_size, int) and isinstance(dilation, int): + padding = (kernel_size - 1) // 2 * dilation + else: + _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation) + kernel_size = _make_ntuple(kernel_size, _conv_dim) + dilation = _make_ntuple(dilation, _conv_dim) + padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim)) + if bias is None: + bias = norm_layer is None + + layers = [ + conv_layer( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + ] + + if norm_layer is not None: + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + params = {} if inplace is None else {"inplace": inplace} + layers.append(activation_layer(**params)) + super().__init__(*layers) + _log_api_usage_once(self) + self.out_channels = out_channels + + if self.__class__ == ConvNormActivation: + warnings.warn( + "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead." + ) + + +class Conv2dNormActivation(ConvNormActivation): + """ + Configurable block used for Convolution2d-Normalization-Activation blocks. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d`` + activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. 
+ + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 3, + stride: Union[int, Tuple[int, int]] = 1, + padding: Optional[Union[int, Tuple[int, int], str]] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: Union[int, Tuple[int, int]] = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + torch.nn.Conv2d, + ) + + +class MLP(torch.nn.Sequential): + """This block implements the multi-layer perceptron (MLP) module. + + Args: + in_channels (int): Number of channels of the input + hidden_channels (List[int]): List of the hidden channel dimensions + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` + activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` + inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. + Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. + bias (bool): Whether to use bias in the linear layer. Default ``True`` + dropout (float): The probability for the dropout layer. Default: 0.0 + """ + + def __init__( + self, + in_channels: int, + hidden_channels: List[int], + norm_layer: Optional[Callable[..., torch.nn.Module]] = None, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + inplace: Optional[bool] = None, + bias: bool = True, + dropout: float = 0.0, + ): + # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: + # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py + params = {} if inplace is None else {"inplace": inplace} + + layers = [] + in_dim = in_channels + for hidden_dim in hidden_channels[:-1]: + layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) + if norm_layer is not None: + layers.append(norm_layer(hidden_dim)) + layers.append(activation_layer(**params)) + layers.append(torch.nn.Dropout(dropout, **params)) + in_dim = hidden_dim + + layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) + layers.append(torch.nn.Dropout(dropout, **params)) + + super().__init__(*layers) + _log_api_usage_once(self) + + +def conv3x3( + in_channels: int, + out_channels: int, + stride: int = 1, + groups: int = 1, + dilation: int = 1, +) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + groups: int = 1, + base_width: int = 64, + dilation: int = 
1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(in_channels, out_channels, stride) + self.bn1 = norm_layer(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + self.stride = stride + if in_channels != out_channels: + self.downsample = nn.Sequential( + conv1x1(in_channels, out_channels), + nn.BatchNorm2d(out_channels), + ) + else: + self.downsample = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + identity = x + + # print("Conv1 weight shape:", self.conv1.weight.shape) + # print("Conv1 weight device:", self.conv1.weight.device) + # print("Input tensor device:", x.device) + # print("Input tensor requires_grad:", x.requires_grad) + # print("Conv1 weight requires_grad:", self.conv1.weight.requires_grad) + # print(f"Memory before conv1: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + out = self.conv1(x) + # print(f"Memory after conv1: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + # print(f"Conv1 output shape: {out.shape}") + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += self.downsample(identity) + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + expansion: int = 4, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(out_channels * (base_width / 64.0)) * groups + self.expansion = expansion + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(in_channels, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, out_channels * self.expansion) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + if in_channels != out_channels: + self.downsample = nn.Sequential( + conv1x1(in_channels, out_channels), + nn.BatchNorm2d(out_channels), + ) + else: + self.downsample = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += self.downsample(identity) + out = self.relu(out) + + return out + + +def _init_weights(model: nn.Module) -> None: + for m in model.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1.) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) 
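For orientation, a minimal sketch of how the blocks above are typically combined downstream: a decoder BasicBlock keeps the spatial resolution, adapts channels through its automatic skip projection, and is initialized with _init_weights. The import path and tensor sizes below are illustrative assumptions, not values taken from this diff.

import torch

# Assumes the repo root is on PYTHONPATH so that models/utils.py (shown above) is importable.
from models.utils import BasicBlock, _init_weights

block = BasicBlock(in_channels=256, out_channels=128)  # channel change -> 1x1 conv + BN on the skip path
block.apply(_init_weights)                             # Kaiming init for convs, (1, 0) for BatchNorm

x = torch.randn(2, 256, 32, 32)                        # illustrative feature map
y = block(x)
print(y.shape)                 # torch.Size([2, 128, 32, 32]) -- spatial size unchanged (stride defaults to 1)

same = BasicBlock(in_channels=128, out_channels=128)
print(type(same.downsample))   # nn.Identity -- no projection when the channel counts already match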
+ + +class Upsample(nn.Module): + def __init__( + self, + size: Union[int, Tuple[int, int]] = None, + scale_factor: Union[float, Tuple[float, float]] = None, + mode: str = "nearest", + align_corners: bool = False, + antialias: bool = False, + ) -> None: + super().__init__() + self.interpolate = partial( + F.interpolate, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + antialias=antialias, + ) + + def forward(self, x: Tensor) -> Tensor: + return self.interpolate(x) + + +def make_vgg_layers(cfg: List[Union[str, int]], in_channels: int = 3, batch_norm: bool = False, dilation: int = 1) -> nn.Sequential: + layers = [] + for v in cfg: + if v == "M": + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + elif v == "U": + layers += [Upsample(scale_factor=2, mode="bilinear")] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=dilation, dilation=dilation) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +def make_resnet_layers( + block: Union[BasicBlock, Bottleneck], + cfg: List[Union[int, str]], + in_channels: int, + dilation: int = 1, + expansion: int = 1, +) -> nn.Sequential: + layers = [] + for v in cfg: + if v == "U": + layers.append(Upsample(scale_factor=2, mode="bilinear")) + else: + layers.append(block( + in_channels=in_channels, + out_channels=v, + dilation=dilation, + expansion=expansion, + )) + in_channels = v + + layers = nn.Sequential(*layers) + layers.apply(_init_weights) + return layers diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..24f5808a1422d17223c76f9636c30fbfefbd97ca --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +einops==0.7.0 +ftfy==6.1.3 +numpy==1.26.4 +Pillow==10.2.0 +regex==2023.12.25 +scipy==1.12.0 +setuptools==69.1.1 +tensorboardX==2.6.2.2 +timm==0.9.16 +torch==2.2.1 +torchvision==0.17.1 +tqdm==4.66.2 +scipy +scikit-learn +streamlit +gradio +matplotlib +seaborn +huggingface_hub +onnx==1.17.0 +onnxruntime-gpu==1.21.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d7536d194e88d692a649283d1e636f26a7196a92 --- /dev/null +++ b/setup.py @@ -0,0 +1,45 @@ +from setuptools import setup, find_packages + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setup( + name="clipebc", + version="0.1.0", + author="jungseoik", + author_email="si.jung@pia.space", + description="Crowd counting with CLIP-EBC", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/jungseoik/CLIP_EBC", + packages=find_packages(), + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + ], + python_requires=">=3.12.4", + install_requires=[ + "einops==0.7.0", + "ftfy==6.1.3", + "numpy==1.26.4", + "Pillow==10.2.0", + "regex==2023.12.25", + "scipy==1.12.0", + "tensorboardX==2.6.2.2", + "timm==0.9.16", + "torch==2.2.1", + "torchvision==0.17.1", + "tqdm==4.66.2", + "scikit-learn", + "matplotlib", + "seaborn" + ], + extras_require={ + "app": ["streamlit", "gradio"], + "dev": ["pytest", "black", "isort"], + } +) \ No newline at end of file diff --git a/utils/__init__.py 
b/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..269bf914eb48fd6905d4d3775c562597fd535707 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,13 @@ +from .ddp_utils import reduce_mean, setup, cleanup, init_seeds, barrier +from .eval_utils import calculate_errors, resize_density_map, sliding_window_predict +from .log_utils import get_logger, get_config, get_writer, print_epoch, print_train_result, print_eval_result, update_train_result, update_eval_result, log, update_loss_info +from .train_utils import cosine_annealing_warm_restarts, get_loss_fn, get_optimizer, load_checkpoint, save_checkpoint +from .data_utils import get_dataloader + + +__all__ = [ + "reduce_mean", "setup", "cleanup", "init_seeds", "barrier", + "calculate_errors", "resize_density_map", "sliding_window_predict", + "get_logger", "get_config", "get_writer", "print_epoch", "print_train_result", "print_eval_result", "update_train_result", "update_eval_result", "log", "update_loss_info", + "get_dataloader", "get_loss_fn", "get_optimizer", "load_checkpoint", "save_checkpoint", +] diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..5c03451662b9bfa9ac333e41d3ed52f9e45ef9bd --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,78 @@ +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.transforms.v2 import Compose +import os, sys +from argparse import ArgumentParser +from typing import Union, Tuple + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +# import datasets + + +def get_dataloader(args: ArgumentParser, split: str = "train", ddp: bool = False) -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]: + if split == "train": # train, strong augmentation + transforms = Compose([ + datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.min_scale, args.max_scale)), + datasets.RandomHorizontalFlip(), + datasets.RandomApply([ + datasets.ColorJitter(brightness=args.brightness, contrast=args.contrast, saturation=args.saturation, hue=args.hue), + datasets.GaussianBlur(kernel_size=args.kernel_size, sigma=(0.1, 5.0)), + datasets.PepperSaltNoise(saltiness=args.saltiness, spiciness=args.spiciness), + ], p=(args.jitter_prob, args.blur_prob, args.noise_prob)), + ]) + + elif args.sliding_window: + if args.resize_to_multiple: + transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride) + elif args.zero_pad_to_multiple: + transforms = datasets.ZeroPad2Multiple(args.window_size, stride=args.stride) + else: + transforms = None + + else: + transforms = None + + dataset = datasets.Crowd( + dataset=args.dataset, + split=split, + transforms=transforms, + sigma=None, + return_filename=False, + num_crops=args.num_crops if split == "train" else 1, + ) + + if ddp and split == "train": # data_loader for training in DDP + sampler = DistributedSampler(dataset) + data_loader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + collate_fn=datasets.collate_fn, + ) + return data_loader, sampler + + elif split == "train": # data_loader for training + data_loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + pin_memory=True, + collate_fn=datasets.collate_fn, + ) + return data_loader, None + + else: # data_loader for evaluation + data_loader = 
DataLoader(
+            dataset,
+            batch_size=1,  # Use batch size 1 for evaluation
+            shuffle=False,
+            num_workers=args.num_workers,
+            pin_memory=True,
+            collate_fn=datasets.collate_fn,
+        )
+        return data_loader
diff --git a/utils/ddp_utils.py b/utils/ddp_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..e94edb2d51e7801a8ae1e143192ed3c994e808e2
--- /dev/null
+++ b/utils/ddp_utils.py
@@ -0,0 +1,44 @@
+import torch
+from torch import Tensor
+import torch.distributed as dist
+import numpy as np
+import random
+import os
+
+
+def reduce_mean(tensor: Tensor, nprocs: int) -> Tensor:
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= nprocs
+    return rt
+
+
+def setup(local_rank: int, nprocs: int) -> None:
+    if nprocs > 1:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+        dist.init_process_group("nccl", rank=local_rank, world_size=nprocs)
+    else:
+        print("Single process. No need to setup dist.")
+
+
+def cleanup(ddp: bool = True) -> None:
+    if ddp:
+        dist.destroy_process_group()
+
+
+def init_seeds(seed: int, cuda_deterministic: bool = False) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if cuda_deterministic:  # slower, but reproducible
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+    else:  # faster, not reproducible
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+
+
+def barrier(ddp: bool = True) -> None:
+    if ddp:
+        dist.barrier()
diff --git a/utils/eval_utils.py b/utils/eval_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..d1a0179790c954d7332dea1a17b84a3e3991c205
--- /dev/null
+++ b/utils/eval_utils.py
@@ -0,0 +1,130 @@
+import torch
+from torch import Tensor, nn
+import torch.nn.functional as F
+import numpy as np
+from typing import Dict, Tuple, Union
+
+# Release cached GPU memory held by PyTorch when this module is imported.
+torch.cuda.empty_cache()
+
+
+def calculate_errors(pred_counts: np.ndarray, gt_counts: np.ndarray) -> Dict[str, float]:
+    assert isinstance(pred_counts, np.ndarray), f"Expected numpy.ndarray, got {type(pred_counts)}"
+    assert isinstance(gt_counts, np.ndarray), f"Expected numpy.ndarray, got {type(gt_counts)}"
+    assert len(pred_counts) == len(gt_counts), f"Length of predictions and ground truths should be equal, but got {len(pred_counts)} and {len(gt_counts)}"
+    errors = {
+        "mae": np.mean(np.abs(pred_counts - gt_counts)),
+        "rmse": np.sqrt(np.mean((pred_counts - gt_counts) ** 2)),
+    }
+    return errors
+
+
+def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
+    # Record the total count of the original density map.
+    x_sum = torch.sum(x, dim=(-1, -2), keepdim=True)
+    # Resize to the target (original image) size with bilinear interpolation.
+    x = F.interpolate(x, size=size, mode="bilinear")
+    # Compute the scaling factor that preserves the total count after resizing.
+    scale_factor = torch.nan_to_num(x_sum / torch.sum(x, dim=(-1, -2), keepdim=True),
+                                    nan=0.0, posinf=0.0, neginf=0.0)
+    # Apply the scaling factor so the resized map still sums to the original count.
+    return x * scale_factor
+
+
+def sliding_window_predict(
+    model: nn.Module,
+    image: Tensor,
+    window_size: Union[int, Tuple[int, int]],
+    stride: Union[int, Tuple[int, int]],
+) -> Tensor:
+    """
+    Generate the density map for an image using the sliding window method. Overlapping regions will be averaged.
+
+    Args:
+        model (nn.Module): The model to use.
+        image (Tensor): The image (1, c, h, w) to generate the density map for. The batch size must be 1 due to varying image sizes.
+        window_size (Union[int, Tuple[int, int]]): The size of the window.
+        stride (Union[int, Tuple[int, int]]): The step size of the window.
+    """
+    assert len(image.shape) == 4, f"Image must be a 4D tensor (1, c, h, w), got {image.shape}"
+    window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
+    stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
+    window_size = tuple(window_size)
+    stride = tuple(stride)
+    assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, f"Window size must be a positive integer tuple (h, w), got {window_size}"
+    assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, f"Stride must be a positive integer tuple (h, w), got {stride}"
+    assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"Stride must be smaller than window size, got {stride} and {window_size}"
+
+    image_height, image_width = image.shape[-2:]
+    window_height, window_width = window_size
+    stride_height, stride_width = stride
+
+    num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
+    num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)
+
+    reduction = model.reduction if hasattr(model, "reduction") else 1  # reduction factor of the model. For example, if reduction = 8, then the density map will be reduced by 8x.
+    windows = []
+    for i in range(num_rows):
+        for j in range(num_cols):
+            x_start, y_start = i * stride_height, j * stride_width
+            x_end, y_end = x_start + window_height, y_start + window_width
+            if x_end > image_height:
+                x_start, x_end = image_height - window_height, image_height
+            if y_end > image_width:
+                y_start, y_end = image_width - window_width, image_width
+
+            window = image[:, :, x_start:x_end, y_start:y_end]
+            windows.append(window)
+
+    windows = torch.cat(windows, dim=0).to(image.device)  # batched windows, shape: (num_windows, c, h, w)
+
+    # Run inference over the windows in mini-batches to keep peak memory bounded.
+    all_preds = []
+    max_batch_size = 8
+    model.eval()
+    with torch.no_grad():
+        for start_idx in range(0, windows.size(0), max_batch_size):
+            end_idx = start_idx + max_batch_size
+            batch_windows = windows[start_idx:end_idx]  # (batch_size, c, h, w)
+            batch_preds = model(batch_windows)
+            all_preds.append(batch_preds.cpu())
+
+    # Concatenate the per-batch predictions and convert to numpy for assembling the full map.
+    preds = torch.cat(all_preds, dim=0)
+    preds = preds.detach().numpy()
+
+    # assemble the density map
+    pred_map = np.zeros((preds.shape[1], image_height // reduction, image_width // reduction), dtype=np.float32)
+    count_map = np.zeros((preds.shape[1], image_height // reduction, image_width // reduction), dtype=np.float32)
+    idx = 0
+    for i in range(num_rows):
+        for j in range(num_cols):
+            x_start, y_start = i * stride_height, j * stride_width
+            x_end, y_end = x_start + window_height, y_start + window_width
+            if x_end > image_height:
+                x_start, x_end = image_height - window_height, image_height
+            if y_end > image_width:
+                y_start, y_end = image_width - window_width, image_width
+
+            pred_map[:, (x_start // reduction): (x_end // reduction), (y_start // reduction): (y_end // reduction)] += preds[idx, :, :, :]
count_map[:, (x_start // reduction): (x_end // reduction), (y_start // reduction): (y_end // reduction)] += 1. + idx += 1 + + pred_map /= count_map # average the overlapping regions + return torch.tensor(pred_map).unsqueeze(0) # shape: (1, 1, h // reduction, w // reduction) diff --git a/utils/log_utils.py b/utils/log_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..34ef2b7a1ca6d6c1051798d513bddf845291041b --- /dev/null +++ b/utils/log_utils.py @@ -0,0 +1,147 @@ +import torch +from torch import Tensor +from tensorboardX import SummaryWriter +import logging +import os +from typing import Dict, Union, Optional, List, Tuple +from collections import OrderedDict + + +def get_logger(log_file: str) -> logging.Logger: + logger = logging.getLogger(log_file) + logger.setLevel(logging.DEBUG) + fh = logging.FileHandler(log_file) + fh.setLevel(logging.DEBUG) + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ch.setFormatter(formatter) + fh.setFormatter(formatter) + logger.addHandler(ch) + logger.addHandler(fh) + return logger + + +def get_config(config: Dict, mute: bool = False) -> str: + config = config.copy() + config = "\n".join([f"{k.ljust(15)}:\t{v}" for k, v in config.items()]) + if not mute: + print(config) + return config + + +def get_writer(ckpt_dir: str) -> SummaryWriter: + return SummaryWriter(log_dir=os.path.join(ckpt_dir, "logs")) + + +def print_epoch(epoch: int, total_epochs: int, mute: bool = False) -> Union[str, None]: + digits = len(str(total_epochs)) + info = f"Epoch: {(epoch):0{digits}d} / {total_epochs:0{digits}d}" + if mute: + return info + print(info) + + +def print_train_result(loss_info: Dict[str, float], mute: bool = False) -> Union[str, None]: + loss_info = [f"{k}: {v};" for k, v in loss_info.items()] + info = "Training: " + " ".join(loss_info) + if mute: + return info + print(info) + + +def print_eval_result(curr_scores: Dict[str, float], best_scores: Dict[str, float], mute: bool = False) -> Union[str, None]: + scores = [] + for k in curr_scores.keys(): + info = f"Curr {k}: {curr_scores[k]:.4f}; \t Best {k}: " + info += " ".join([f"{best_scores[k][i]:.4f};" for i in range(len(best_scores[k]))]) + scores.append(info) + + info = "Evaluation:\n" + "\n".join(scores) + if mute: + return info + print(info) + + +def update_train_result(epoch: int, loss_info: Dict[str, float], writer: SummaryWriter) -> None: + for k, v in loss_info.items(): + writer.add_scalar(f"train/{k}", v, epoch) + + +def update_eval_result( + epoch: int, + curr_scores: Dict[str, float], + hist_scores: Dict[str, List[float]], + best_scores: Dict[str, List[float]], + writer: SummaryWriter, + state_dict: OrderedDict[str, Tensor], + ckpt_dir: str, +) -> Tuple[Dict[str, List[float]], Dict[str, float]]: + os.makedirs(ckpt_dir, exist_ok=True) + for k, v in curr_scores.items(): + hist_scores[k].append(v) + writer.add_scalar(f"val/{k}", v, epoch) + + # best_scores[k][0] is the best score. Smaller is better. + # Find the location idx where the new score v should be inserted + loc = None + for i in range(len(best_scores[k])): + if v < best_scores[k][i]: + best_scores[k].insert(i, v) # Add the new best score to the location i + loc = i + break + + # If the new score is better than the worst best score + if loc is not None: + # Update the best scores + best_scores[k] = best_scores[k][:len(best_scores[k]) - 1] + + # Rename the best_{k}_{i}.pth to best_{k}_{i+1}.pth, best_{k}_{i+1}.pth to best_{k}_{i+2}.pth ... 
+ for i in range(len(best_scores[k]) - 1, loc, -1): + if os.path.exists(os.path.join(ckpt_dir, f"best_{k}_{i-1}.pth")): + os.rename(os.path.join(ckpt_dir, f"best_{k}_{i-1}.pth"), os.path.join(ckpt_dir, f"best_{k}_{i}.pth")) + + # Save the best checkpoint + torch.save(state_dict, os.path.join(ckpt_dir, f"best_{k}_{loc}.pth")) + + return hist_scores, best_scores + + +def update_loss_info(hist_scores: Union[Dict[str, List[float]], None], curr_scores: Dict[str, float]) -> Dict[str, List[float]]: + assert all([isinstance(v, float) for v in curr_scores.values()]), f"Expected all values to be float, got {curr_scores}" + if hist_scores is None or len(hist_scores) == 0: + hist_scores = {k: [v] for k, v in curr_scores.items()} + else: + for k, v in curr_scores.items(): + hist_scores[k].append(v) + return hist_scores + + +def log( + logger: logging.Logger, + epoch: int, + total_epochs: int, + loss_info: Optional[Dict[str, float]] = None, + curr_scores: Optional[Dict[str, float]] = None, + best_scores: Optional[Dict[str, float]] = None, + message: Optional[str] = None, +) -> None: + if epoch is None: + assert total_epochs is None, f"Expected total_epochs to be None when epoch is None, got {total_epochs}" + msg = "" + else: + assert total_epochs is not None, f"Expected total_epochs to be not None when epoch is not None, got {total_epochs}" + msg = print_epoch(epoch, total_epochs, mute=True) + + if loss_info is not None: + msg += "\n" if len(msg) > 0 else "" + msg += print_train_result(loss_info, mute=True) + + if curr_scores is not None: + assert best_scores is not None, f"Expected best_scores to be not None when curr_scores is not None, got {best_scores}" + msg += "\n" if len(msg) > 0 else "" + msg += print_eval_result(curr_scores, best_scores, mute=True) + + msg += message if message is not None else "" + + logger.info(msg) diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..03e27fc972cfa1d7a0f04b9cc6d2b91d7e4e9ec0 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,157 @@ +import torch +from torch import nn, Tensor + +from torch.optim import Adam +from torch.cuda.amp import GradScaler +from torch.optim.lr_scheduler import LambdaLR + +from functools import partial +from argparse import ArgumentParser + +import os, sys, math +from typing import Union, Tuple, Dict, List +from collections import OrderedDict + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +# import losses + + +def cosine_annealing_warm_restarts( + epoch: int, + base_lr: float, + warmup_epochs: int, + warmup_lr: float, + T_0: int, + T_mult: int, + eta_min: float, +) -> float: + """ + Learning rate scheduler. + The learning rate will linearly increase from warmup_lr to lr in the first warmup_epochs epochs. + Then, the learning rate will follow the cosine annealing with warm restarts strategy. + """ + assert epoch >= 0, f"epoch must be non-negative, got {epoch}." + assert isinstance(warmup_epochs, int) and warmup_epochs >= 0, f"warmup_epochs must be non-negative, got {warmup_epochs}." + assert isinstance(warmup_lr, float) and warmup_lr > 0, f"warmup_lr must be positive, got {warmup_lr}." + assert isinstance(T_0, int) and T_0 >= 1, f"T_0 must be greater than or equal to 1, got {T_0}." + assert isinstance(T_mult, int) and T_mult >= 1, f"T_mult must be greater than or equal to 1, got {T_mult}." + assert isinstance(eta_min, float) and eta_min > 0, f"eta_min must be positive, got {eta_min}." 
+ assert isinstance(base_lr, float) and base_lr > 0, f"base_lr must be positive, got {base_lr}." + assert base_lr > eta_min, f"base_lr must be greater than eta_min, got base_lr={base_lr} and eta_min={eta_min}." + assert warmup_lr >= eta_min, f"warmup_lr must be greater than or equal to eta_min, got warmup_lr={warmup_lr} and eta_min={eta_min}." + + if epoch < warmup_epochs: + lr = warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs + else: + epoch -= warmup_epochs + if T_mult == 1: + T_cur = epoch % T_0 + T_i = T_0 + else: + n = int(math.log((epoch / T_0 * (T_mult - 1) + 1), T_mult)) + T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1) + T_i = T_0 * T_mult ** (n) + + lr = eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 + + return lr / base_lr + + +def get_loss_fn(args: ArgumentParser) -> nn.Module: + if args.bins is None: + assert args.weight_ot is not None and args.weight_tv is not None, f"Expected weight_ot and weight_tv to be not None, got {args.weight_ot} and {args.weight_tv}" + loss_fn = losses.DMLoss( + input_size=args.input_size, + reduction=args.reduction, + ) + else: + loss_fn = losses.DACELoss( + bins=args.bins, + reduction=args.reduction, + weight_count_loss=args.weight_count_loss, + count_loss=args.count_loss, + input_size=args.input_size, + ) + return loss_fn + + +def get_optimizer(args: ArgumentParser, model: nn.Module) -> Tuple[Adam, LambdaLR]: + optimizer = Adam( + params=filter(lambda p: p.requires_grad, model.parameters()), + lr=args.lr, + weight_decay=args.weight_decay + ) + + scheduler = LambdaLR( + optimizer=optimizer, + lr_lambda=partial( + cosine_annealing_warm_restarts, + warmup_epochs=args.warmup_epochs, + warmup_lr=args.warmup_lr, + T_0=args.T_0, + T_mult=args.T_mult, + eta_min=args.eta_min, + base_lr=args.lr + ), + ) + + return optimizer, scheduler + + +def load_checkpoint( + args: ArgumentParser, + model: nn.Module, + optimizer: Adam, + scheduler: LambdaLR, + grad_scaler: GradScaler, +) -> Tuple[nn.Module, Adam, Union[LambdaLR, None], GradScaler, int, Union[Dict[str, float], None], Dict[str, List[float]], Dict[str, float]]: + ckpt_path = os.path.join(args.ckpt_dir, "ckpt.pth") + if os.path.exists(ckpt_path): + ckpt = torch.load(ckpt_path) + model.load_state_dict(ckpt["model_state_dict"]) + optimizer.load_state_dict(ckpt["optimizer_state_dict"]) + start_epoch = ckpt["epoch"] + loss_info = ckpt["loss_info"] + hist_scores = ckpt["hist_scores"] + best_scores = ckpt["best_scores"] + + if scheduler is not None: + scheduler.load_state_dict(ckpt["scheduler_state_dict"]) + if grad_scaler is not None: + grad_scaler.load_state_dict(ckpt["grad_scaler_state_dict"]) + + print(f"Loaded checkpoint from {ckpt_path}.") + + else: + start_epoch = 1 + loss_info, hist_scores = None, {"mae": [], "rmse": []} + best_scores = {k: [torch.inf] * args.save_best_k for k in hist_scores.keys()} + print(f"Checkpoint not found at {ckpt_path}.") + + return model, optimizer, scheduler, grad_scaler, start_epoch, loss_info, hist_scores, best_scores + + +def save_checkpoint( + epoch: int, + model_state_dict: OrderedDict[str, Tensor], + optimizer_state_dict: OrderedDict[str, Tensor], + scheduler_state_dict: OrderedDict[str, Tensor], + grad_scaler_state_dict: OrderedDict[str, Tensor], + loss_info: Dict[str, List[float]], + hist_scores: Dict[str, List[float]], + best_scores: Dict[str, float], + ckpt_dir: str, +) -> None: + ckpt = { + "epoch": epoch, + "model_state_dict": model_state_dict, + "optimizer_state_dict": optimizer_state_dict, + "scheduler_state_dict": 
scheduler_state_dict, + "grad_scaler_state_dict": grad_scaler_state_dict, + "loss_info": loss_info, + "hist_scores": hist_scores, + "best_scores": best_scores, + } + torch.save(ckpt, os.path.join(ckpt_dir, "ckpt.pth"))
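As a closing reference, the schedule above is meant to plug into LambdaLR as a multiplicative factor, since cosine_annealing_warm_restarts returns lr / base_lr. Below is a minimal sketch mirroring get_optimizer; the placeholder model, the hyperparameter values, and the import path are illustrative assumptions, not taken from this diff.

from functools import partial

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

# Assumes the repo root is on PYTHONPATH so utils/train_utils.py (shown above) is importable.
from utils.train_utils import cosine_annealing_warm_restarts

model = torch.nn.Linear(8, 1)  # placeholder module, only used to give the optimizer some parameters
optimizer = Adam(model.parameters(), lr=1e-4)

scheduler = LambdaLR(
    optimizer,
    lr_lambda=partial(
        cosine_annealing_warm_restarts,
        base_lr=1e-4,      # illustrative values; in the repo they come from the ArgumentParser
        warmup_epochs=5,
        warmup_lr=1e-5,
        T_0=10,
        T_mult=2,
        eta_min=1e-6,
    ),
)

for epoch in range(30):
    # ... run one training epoch here ...
    scheduler.step()  # LambdaLR evaluates the lambda at the current epoch index and rescales the base lr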