Spaces:
Runtime error
Runtime error
[first]
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- __init__.py +3 -0
- app.py +60 -0
- assets/__init__.py +35 -0
- configs/reduction_16.json +33 -0
- configs/reduction_32.json +56 -0
- configs/reduction_8.json +129 -0
- custom/airport_color.py +54 -0
- custom/clip_ebc.py +346 -0
- custom/clip_ebc_onnx.py +465 -0
- custom/clip_ebc_tensorrt.py +603 -0
- custom/init_get_model.py +0 -0
- custom/json2seg.py +16 -0
- custom/mock_gen.py +80 -0
- custom/visual.py +100 -0
- losses/__init__.py +7 -0
- losses/bregman_pytorch.py +144 -0
- losses/dace_loss.py +70 -0
- losses/dm_loss.py +124 -0
- losses/utils.py +9 -0
- main.py +75 -0
- models/__init__.py +49 -0
- models/clip/__init__.py +7 -0
- models/clip/_clip/__init__.py +273 -0
- models/clip/_clip/blocks.py +137 -0
- models/clip/_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- models/clip/_clip/configs/clip_image_encoder_resnet101.json +13 -0
- models/clip/_clip/configs/clip_image_encoder_resnet50.json +13 -0
- models/clip/_clip/configs/clip_image_encoder_resnet50x16.json +13 -0
- models/clip/_clip/configs/clip_image_encoder_resnet50x4.json +13 -0
- models/clip/_clip/configs/clip_image_encoder_resnet50x64.json +13 -0
- models/clip/_clip/configs/clip_image_encoder_vit_b_16.json +8 -0
- models/clip/_clip/configs/clip_image_encoder_vit_b_32.json +8 -0
- models/clip/_clip/configs/clip_image_encoder_vit_l_14.json +8 -0
- models/clip/_clip/configs/clip_image_encoder_vit_l_14_336px.json +8 -0
- models/clip/_clip/configs/clip_resnet101.json +17 -0
- models/clip/_clip/configs/clip_resnet50.json +17 -0
- models/clip/_clip/configs/clip_resnet50x16.json +17 -0
- models/clip/_clip/configs/clip_resnet50x4.json +17 -0
- models/clip/_clip/configs/clip_resnet50x64.json +17 -0
- models/clip/_clip/configs/clip_text_encoder_resnet101.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_resnet50.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_resnet50x16.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_resnet50x4.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_resnet50x64.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_vit_b_16.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_vit_b_32.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_vit_l_14.json +8 -0
- models/clip/_clip/configs/clip_text_encoder_vit_l_14_336px.json +8 -0
- models/clip/_clip/configs/clip_vit_b_16.json +12 -0
- models/clip/_clip/configs/clip_vit_b_32.json +12 -0
__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from custom.clip_ebc import ClipEBC
|
2 |
+
|
3 |
+
__all__ = ["ClipEBC"]
|
app.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from custom.clip_ebc_onnx import ClipEBCOnnx
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
# ONNX 모델 초기화
|
7 |
+
model = ClipEBCOnnx()
|
8 |
+
|
9 |
+
def predict_crowd(image):
|
10 |
+
"""
|
11 |
+
이미지를 받아서 군중 수를 예측하고 시각화 결과를 반환합니다.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
image: Gradio에서 받은 이미지 (numpy array)
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple: (예측된 군중 수, 밀도 맵 시각화, 점 시각화)
|
18 |
+
"""
|
19 |
+
count = model.predict(image)
|
20 |
+
|
21 |
+
# 밀도 맵 시각화
|
22 |
+
fig_density, density_map = model.visualize_density_map()
|
23 |
+
plt.close(fig_density) # 메모리 누수 방지
|
24 |
+
# 점 시각화
|
25 |
+
canvas, dot_map = model.visualize_dots()
|
26 |
+
plt.close(canvas.figure)
|
27 |
+
|
28 |
+
return (
|
29 |
+
f"예측된 군중 수: {count:.1f}명",
|
30 |
+
density_map,
|
31 |
+
dot_map
|
32 |
+
)
|
33 |
+
|
34 |
+
with gr.Blocks(title="CLIP-EBC Crowd Counter") as app:
|
35 |
+
gr.Markdown("# CLIP-EBC Crowd Counter")
|
36 |
+
gr.Markdown("이미지를 업로드하여 군중 수를 예측하고 시각화합니다.")
|
37 |
+
|
38 |
+
with gr.Row():
|
39 |
+
input_image = gr.Image(type="numpy", label="입력 이미지")
|
40 |
+
|
41 |
+
with gr.Row():
|
42 |
+
predict_btn = gr.Button("예측", variant="primary")
|
43 |
+
|
44 |
+
with gr.Row():
|
45 |
+
count_text = gr.Textbox(label="예측 결과")
|
46 |
+
|
47 |
+
with gr.Row():
|
48 |
+
with gr.Column():
|
49 |
+
density_output = gr.Image(label="밀도 맵")
|
50 |
+
with gr.Column():
|
51 |
+
dots_output = gr.Image(label="점 시각화")
|
52 |
+
|
53 |
+
predict_btn.click(
|
54 |
+
fn=predict_crowd,
|
55 |
+
inputs=input_image,
|
56 |
+
outputs=[count_text, density_output, dots_output]
|
57 |
+
)
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
app.launch(share=False)
|
assets/__init__.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
+
import os
|
3 |
+
|
4 |
+
def download_required_files():
|
5 |
+
"""Initialize required files from Hugging Face Hub"""
|
6 |
+
try:
|
7 |
+
cache_dir = "assets/"
|
8 |
+
if not os.path.exists(os.path.join(cache_dir, "CLIP_EBC_nwpu_rmse_onnx.onnx")):
|
9 |
+
hf_hub_download(
|
10 |
+
repo_id="PIA-SPACE-LAB/CLIP_EBC_nwpu_rmse_onnx",
|
11 |
+
filename="CLIP_EBC_nwpu_rmse_onnx.onnx",
|
12 |
+
# cache_dir=cache_dir,
|
13 |
+
local_dir=cache_dir
|
14 |
+
)
|
15 |
+
print("Required files downloaded successfully")
|
16 |
+
except Exception as e:
|
17 |
+
print(f"Error downloading required files: {e}")
|
18 |
+
|
19 |
+
def download_required_files2():
|
20 |
+
"""Initialize required files from Hugging Face Hub"""
|
21 |
+
try:
|
22 |
+
cache_dir = "assets/"
|
23 |
+
if not os.path.exists(os.path.join(cache_dir, "CLIP_EBC_nwpu_rmse.pth")):
|
24 |
+
hf_hub_download(
|
25 |
+
repo_id="PIA-SPACE-LAB/CLIP_EBC_nwpu_rmse",
|
26 |
+
filename="CLIP_EBC_nwpu_rmse.pth",
|
27 |
+
# cache_dir=cache_dir,
|
28 |
+
local_dir=cache_dir
|
29 |
+
)
|
30 |
+
print("Required files downloaded successfully")
|
31 |
+
except Exception as e:
|
32 |
+
print(f"Error downloading required files: {e}")
|
33 |
+
|
34 |
+
download_required_files()
|
35 |
+
download_required_files2()
|
configs/reduction_16.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"8":{
|
3 |
+
"qnrf": {
|
4 |
+
"bins": {
|
5 |
+
"fine":[
|
6 |
+
[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
|
7 |
+
[5, 5], [6, 6], [7, 7], [8, "inf"]
|
8 |
+
],
|
9 |
+
"dynamic": [
|
10 |
+
[0, 0], [1, 1], [2, 2], [3, 3],
|
11 |
+
[4, 5], [6, 7], [8, "inf"]
|
12 |
+
],
|
13 |
+
"coarse": [
|
14 |
+
[0, 0], [1, 2], [3, 4], [5, 6], [7, "inf"]
|
15 |
+
]
|
16 |
+
},
|
17 |
+
"anchor_points": {
|
18 |
+
"fine": {
|
19 |
+
"middle": [0, 1, 2, 3, 4, 5, 6, 7, 8],
|
20 |
+
"average": [0, 1, 2, 3, 4, 5, 6, 7, 9.23349]
|
21 |
+
},
|
22 |
+
"dynamic": {
|
23 |
+
"middle": [0, 1, 2, 3, 4.5, 6.5, 8],
|
24 |
+
"average": [0, 1, 2, 3, 4.29278, 6.31441, 9.23349]
|
25 |
+
},
|
26 |
+
"coarse": {
|
27 |
+
"middle": [0, 1.5, 3.5, 5.5, 7],
|
28 |
+
"average": [0, 1.14978, 3.27641, 5.30609, 8.11466]
|
29 |
+
}
|
30 |
+
}
|
31 |
+
}
|
32 |
+
}
|
33 |
+
}
|
configs/reduction_32.json
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"19": {
|
3 |
+
"qnrf": {
|
4 |
+
"bins": {
|
5 |
+
"fine": [
|
6 |
+
[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
|
7 |
+
[5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
|
8 |
+
[10, 10], [11, 11], [12, 12], [13, 13], [14, 14],
|
9 |
+
[15, 15], [16, 16], [17, 17], [18, 18], [19, "inf"]
|
10 |
+
],
|
11 |
+
"dynamic": [
|
12 |
+
[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
|
13 |
+
[5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
|
14 |
+
[10, 11], [12, 13], [14, 15], [16, 17], [18, "inf"]
|
15 |
+
],
|
16 |
+
"coarse": [
|
17 |
+
[0, 0], [1, 2], [3, 4], [5, 6], [7, 8],
|
18 |
+
[9, 10], [11, 12], [13, 14], [15, 16], [17, 18],
|
19 |
+
[19, "inf"]
|
20 |
+
]
|
21 |
+
},
|
22 |
+
"anchor_points": {
|
23 |
+
"fine": {
|
24 |
+
"middle": [
|
25 |
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
26 |
+
11, 12, 13, 14, 15, 16, 17, 18, 19
|
27 |
+
],
|
28 |
+
"average": [
|
29 |
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
30 |
+
11, 12, 13, 14, 15, 16, 17, 18, 23.01897
|
31 |
+
]
|
32 |
+
},
|
33 |
+
"dynamic": {
|
34 |
+
"middle": [
|
35 |
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10.5,
|
36 |
+
12.5, 14.5, 16.5, 18
|
37 |
+
],
|
38 |
+
"average": [
|
39 |
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10.42903,
|
40 |
+
12.43320, 14.43341, 16.43521, 21.93548
|
41 |
+
]
|
42 |
+
},
|
43 |
+
"coarse": {
|
44 |
+
"middle": [
|
45 |
+
0, 1.5, 3.5, 5.5, 7.5, 9.5,
|
46 |
+
11.5, 13.5, 15.5, 17.5, 19
|
47 |
+
],
|
48 |
+
"average": [
|
49 |
+
0, 1.23498, 3.36108, 5.40298, 7.41406, 9.42356,
|
50 |
+
11.43094, 13.43244, 15.43697, 17.43759, 23.01897
|
51 |
+
]
|
52 |
+
}
|
53 |
+
}
|
54 |
+
}
|
55 |
+
}
|
56 |
+
}
|
configs/reduction_8.json
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"2": {
|
3 |
+
"sha": {
|
4 |
+
"bins": {
|
5 |
+
"fine": [[0, 0], [1, 1], [2, "inf"]]
|
6 |
+
},
|
7 |
+
"anchor_points": {
|
8 |
+
"fine": {
|
9 |
+
"middle": [0, 1, 2],
|
10 |
+
"average": [0, 1, 2.24479]
|
11 |
+
}
|
12 |
+
}
|
13 |
+
},
|
14 |
+
"shb": {
|
15 |
+
"bins": {
|
16 |
+
"fine": [[0, 0], [1, 1], [2, "inf"]]
|
17 |
+
},
|
18 |
+
"anchor_points": {
|
19 |
+
"fine": {
|
20 |
+
"middle": [0, 1, 2],
|
21 |
+
"average": [0, 1, 2.15171]
|
22 |
+
}
|
23 |
+
}
|
24 |
+
},
|
25 |
+
"nwpu": {
|
26 |
+
"bins": {
|
27 |
+
"fine": [[0, 0], [1, 1], [2, "inf"]]
|
28 |
+
},
|
29 |
+
"anchor_points": {
|
30 |
+
"fine": {
|
31 |
+
"middle": [0, 1, 2],
|
32 |
+
"average": [0, 1, 2.10737]
|
33 |
+
}
|
34 |
+
}
|
35 |
+
},
|
36 |
+
"qnrf": {
|
37 |
+
"bins": {
|
38 |
+
"fine": [[0, 0], [1, 1], [2, "inf"]]
|
39 |
+
},
|
40 |
+
"anchor_points": {
|
41 |
+
"fine": {
|
42 |
+
"middle": [0, 1, 2],
|
43 |
+
"average": [0, 1, 2.09296]
|
44 |
+
}
|
45 |
+
}
|
46 |
+
},
|
47 |
+
"jhu": {
|
48 |
+
"bins": {
|
49 |
+
"fine": [[0, 0], [1, 1], [2, "inf"]]
|
50 |
+
},
|
51 |
+
"anchor_points": {
|
52 |
+
"fine": {
|
53 |
+
"middle": [0, 1, 2],
|
54 |
+
"average": [0, 1, 2.18589]
|
55 |
+
}
|
56 |
+
}
|
57 |
+
}
|
58 |
+
},
|
59 |
+
"4": {
|
60 |
+
"sha": {
|
61 |
+
"bins": {
|
62 |
+
"fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
|
63 |
+
},
|
64 |
+
"anchor_points": {
|
65 |
+
"fine": {
|
66 |
+
"middle": [0, 1, 2, 3, 4],
|
67 |
+
"average": [0, 1, 2, 3, 4.29992]
|
68 |
+
}
|
69 |
+
}
|
70 |
+
},
|
71 |
+
"shb": {
|
72 |
+
"bins": {
|
73 |
+
"fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
|
74 |
+
},
|
75 |
+
"anchor_points": {
|
76 |
+
"fine": {
|
77 |
+
"middle": [0, 1, 2, 3, 4],
|
78 |
+
"average": [0, 1, 2, 3, 4.41009]
|
79 |
+
}
|
80 |
+
}
|
81 |
+
},
|
82 |
+
"nwpu": {
|
83 |
+
"bins": {
|
84 |
+
"fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
|
85 |
+
},
|
86 |
+
"anchor_points": {
|
87 |
+
"fine": {
|
88 |
+
"middle": [0, 1, 2, 3, 4],
|
89 |
+
"average": [0, 1, 2, 3, 4.21931]
|
90 |
+
}
|
91 |
+
}
|
92 |
+
},
|
93 |
+
"qnrf": {
|
94 |
+
"bins": {
|
95 |
+
"fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
|
96 |
+
},
|
97 |
+
"anchor_points": {
|
98 |
+
"fine": {
|
99 |
+
"middle": [0, 1, 2, 3, 4],
|
100 |
+
"average": [0, 1, 2, 3, 4.21937]
|
101 |
+
}
|
102 |
+
}
|
103 |
+
},
|
104 |
+
"jhu": {
|
105 |
+
"bins": {
|
106 |
+
"fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]]
|
107 |
+
},
|
108 |
+
"anchor_points": {
|
109 |
+
"fine": {
|
110 |
+
"middle": [0, 1, 2, 3, 4],
|
111 |
+
"average": [0, 1, 2, 3, 4.24058]
|
112 |
+
}
|
113 |
+
}
|
114 |
+
}
|
115 |
+
},
|
116 |
+
"11": {
|
117 |
+
"qnrf": {
|
118 |
+
"bins": {
|
119 |
+
"fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, "inf"]]
|
120 |
+
},
|
121 |
+
"anchor_points": {
|
122 |
+
"fine": {
|
123 |
+
"middle": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
124 |
+
"average": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
125 |
+
}
|
126 |
+
}
|
127 |
+
}
|
128 |
+
}
|
129 |
+
}
|
custom/airport_color.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageDraw
|
2 |
+
from custom.json2seg import get_segmentation_by_id
|
3 |
+
import random
|
4 |
+
INCHEON = "/home/jungseoik/data/PR/CLIP-EBC/assets/incheon.jpg"
|
5 |
+
COLOR_PAIR = {1: '빨간색', 2: '주황색', 3: '노란색', 4: '초록색', 5: '빨간색', 6: '초록색'}
|
6 |
+
|
7 |
+
def generate_random_color_pair():
|
8 |
+
colors = ['빨간색', '주황색', '노란색', '초록색']
|
9 |
+
return {i: random.choice(colors) for i in range(1, 7)}
|
10 |
+
|
11 |
+
def create_mask(segmentation, img_size, color):
|
12 |
+
mask = Image.new('RGBA', img_size, (0, 0, 0, 0))
|
13 |
+
draw = ImageDraw.Draw(mask)
|
14 |
+
|
15 |
+
polygon = segmentation[0]
|
16 |
+
points = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)]
|
17 |
+
|
18 |
+
color_map = {
|
19 |
+
'빨간색': (255, 0, 0, 128),
|
20 |
+
'주황색': (255, 165, 0, 128),
|
21 |
+
'노란색': (255, 255, 0, 128),
|
22 |
+
'초록색': (0, 255, 0, 128),
|
23 |
+
'파란색': (0, 0, 255, 128),
|
24 |
+
'보라색': (128, 0, 128, 128)
|
25 |
+
}
|
26 |
+
|
27 |
+
draw.polygon(points, fill=color_map[color])
|
28 |
+
return mask
|
29 |
+
|
30 |
+
def create_all_masks(img_size, region_color_pairs):
|
31 |
+
"""
|
32 |
+
Parameters:
|
33 |
+
- img_size: 이미지 크기
|
34 |
+
- region_color_pairs: Dictionary 형태로 {region_id: color} 매핑
|
35 |
+
예: {1: '빨간색', 2: '초록색', 3: '노란색', ...}
|
36 |
+
"""
|
37 |
+
# 최종 마스크 생성
|
38 |
+
final_mask = Image.new('RGBA', img_size, (0, 0, 0, 0))
|
39 |
+
|
40 |
+
# 입력받은 region_color_pairs에 따라 마스크 생성 및 합성
|
41 |
+
for region_id, color in region_color_pairs.items():
|
42 |
+
segmentation = get_segmentation_by_id(target_id=region_id)
|
43 |
+
region_mask = create_mask(segmentation, img_size, color)
|
44 |
+
final_mask = Image.alpha_composite(final_mask, region_mask)
|
45 |
+
|
46 |
+
return final_mask
|
47 |
+
|
48 |
+
def airport_map_color(color_pairs = COLOR_PAIR):
|
49 |
+
# region_color_pairs = COLOR_PAIR
|
50 |
+
region_color_pairs = generate_random_color_pair()
|
51 |
+
image = Image.open(INCHEON)
|
52 |
+
all_masks = create_all_masks(image.size, region_color_pairs)
|
53 |
+
result = Image.alpha_composite(image.convert('RGBA'), all_masks)
|
54 |
+
return result
|
custom/clip_ebc.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
from torchvision.transforms.functional import normalize, to_pil_image
|
5 |
+
from torchvision.transforms import ToTensor, Normalize
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import json
|
8 |
+
from models import get_model
|
9 |
+
from utils import resize_density_map, sliding_window_predict
|
10 |
+
from PIL import Image
|
11 |
+
import numpy as np
|
12 |
+
from scipy.ndimage import gaussian_filter
|
13 |
+
from sklearn.cluster import KMeans
|
14 |
+
import datetime
|
15 |
+
from typing import Optional
|
16 |
+
from typing import Union
|
17 |
+
|
18 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
19 |
+
sys.path.append(project_root)
|
20 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
|
21 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
|
24 |
+
class ClipEBC:
|
25 |
+
"""
|
26 |
+
CLIP-EBC (Efficient Boundary Counting) 이미지 처리 클래스입니다.
|
27 |
+
|
28 |
+
CLIP 모델을 사용하여 이미지를 처리하며, 슬라이딩 윈도우 예측 기능을 포함한
|
29 |
+
다양한 설정 옵션을 제공합니다.
|
30 |
+
|
31 |
+
Attributes:
|
32 |
+
truncation (int): 잘라내기 매개변수. 기본값 4.
|
33 |
+
reduction (int): 축소 비율. 기본값 8.
|
34 |
+
granularity (str): 세분화 수준. 기본값 "fine".
|
35 |
+
anchor_points (str): 앵커 포인트 방법. 기본값 "average".
|
36 |
+
model_name (str): CLIP 모델 이름. 기본값 "clip_vit_b_16".
|
37 |
+
input_size (int): 입력 이미지 크기. 기본값 224.
|
38 |
+
window_size (int): 슬라이딩 윈도우 크기. 기본값 224.
|
39 |
+
stride (int): 슬라이딩 윈도우 이동 간격. 기본값 224.
|
40 |
+
prompt_type (str): 프롬프트 유형. 기본값 "word".
|
41 |
+
dataset_name (str): 데이터셋 이름. 기본값 "qnrf".
|
42 |
+
num_vpt (int): 비주얼 프롬프트 토큰 수. 기본값 32.
|
43 |
+
vpt_drop (float): 비주얼 프롬프트 토큰 드롭아웃 비율. 기본값 0.0.
|
44 |
+
deep_vpt (bool): 깊은 비주얼 프롬프트 토큰 사용 여부. 기본값 True.
|
45 |
+
mean (tuple): 정규화를 위한 평균값. 기본값 (0.485, 0.456, 0.406).
|
46 |
+
std (tuple): 정규화를 위한 표준편차값. 기본값 (0.229, 0.224, 0.225).
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self,
|
50 |
+
truncation=4,
|
51 |
+
reduction=8,
|
52 |
+
granularity="fine",
|
53 |
+
anchor_points="average",
|
54 |
+
model_name="clip_vit_b_16",
|
55 |
+
input_size=224,
|
56 |
+
window_size=224,
|
57 |
+
stride=224,
|
58 |
+
prompt_type="word",
|
59 |
+
dataset_name="qnrf",
|
60 |
+
num_vpt=32,
|
61 |
+
vpt_drop=0.,
|
62 |
+
deep_vpt=True,
|
63 |
+
mean=(0.485, 0.456, 0.406),
|
64 |
+
std=(0.229, 0.224, 0.225),
|
65 |
+
config_dir="configs"):
|
66 |
+
"""CLIPEBC 클래스를 설정 매개변수와 함께 초기화합니다."""
|
67 |
+
self.truncation = truncation
|
68 |
+
self.reduction = reduction
|
69 |
+
self.granularity = granularity
|
70 |
+
self.anchor_points_type = anchor_points # 원래 입력값 저장
|
71 |
+
self.model_name = model_name
|
72 |
+
self.input_size = input_size
|
73 |
+
self.window_size = window_size
|
74 |
+
self.stride = stride
|
75 |
+
self.prompt_type = prompt_type
|
76 |
+
self.dataset_name = dataset_name
|
77 |
+
self.num_vpt = num_vpt
|
78 |
+
self.vpt_drop = vpt_drop
|
79 |
+
self.deep_vpt = deep_vpt
|
80 |
+
self.mean = mean
|
81 |
+
self.std = std
|
82 |
+
self.config_dir = config_dir
|
83 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
84 |
+
|
85 |
+
self.bins = None
|
86 |
+
self.anchor_points = None
|
87 |
+
self.model = None
|
88 |
+
|
89 |
+
# 초기 설정 로드 및 모델 초기화
|
90 |
+
self._load_config()
|
91 |
+
self._initialize_model()
|
92 |
+
|
93 |
+
def _load_config(self):
|
94 |
+
"""설정 파일을 로드하고 bins와 anchor_points를 설정합니다."""
|
95 |
+
config_path = os.path.join(self.config_dir, f"reduction_{self.reduction}.json")
|
96 |
+
with open(config_path, "r") as f:
|
97 |
+
config = json.load(f)[str(self.truncation)][self.dataset_name]
|
98 |
+
|
99 |
+
self.bins = config["bins"][self.granularity]
|
100 |
+
self.bins = [(float(b[0]), float(b[1])) for b in self.bins]
|
101 |
+
|
102 |
+
if self.anchor_points_type == "average":
|
103 |
+
self.anchor_points = config["anchor_points"][self.granularity]["average"]
|
104 |
+
else:
|
105 |
+
self.anchor_points = config["anchor_points"][self.granularity]["middle"]
|
106 |
+
self.anchor_points = [float(p) for p in self.anchor_points]
|
107 |
+
|
108 |
+
def _initialize_model(self):
|
109 |
+
"""CLIP 모델을 초기화합니다."""
|
110 |
+
self.model = get_model(
|
111 |
+
backbone=self.model_name,
|
112 |
+
input_size=self.input_size,
|
113 |
+
reduction=self.reduction,
|
114 |
+
bins=self.bins,
|
115 |
+
anchor_points=self.anchor_points,
|
116 |
+
prompt_type=self.prompt_type,
|
117 |
+
num_vpt=self.num_vpt,
|
118 |
+
vpt_drop=self.vpt_drop,
|
119 |
+
deep_vpt=self.deep_vpt
|
120 |
+
)
|
121 |
+
|
122 |
+
ckpt_path = "assets/CLIP_EBC_nwpu_rmse.pth"
|
123 |
+
ckpt = torch.load(ckpt_path, map_location=device)
|
124 |
+
self.model.load_state_dict(ckpt)
|
125 |
+
self.model = self.model.to(device)
|
126 |
+
self.model.eval()
|
127 |
+
|
128 |
+
def visualize_density_map(self, alpha: float = 0.5, save: bool = False,
|
129 |
+
save_path: Optional[str] = None):
|
130 |
+
"""
|
131 |
+
현재 저장된 예측 결과를 시각화합니다.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
alpha (float): density map의 투명도 (0~1). 기본값 0.5
|
135 |
+
save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False
|
136 |
+
save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장.
|
137 |
+
기본값 None
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
Tuple[matplotlib.figure.Figure, np.ndarray]:
|
141 |
+
- density map이 오버레이된 matplotlib Figure 객체
|
142 |
+
- RGB 형식의 시각화된 이미지 배열 (H, W, 3)
|
143 |
+
Raises:
|
144 |
+
ValueError: density_map 또는 processed_image가 None인 경우 (predict 메서드가 실행되지 않은 경우)
|
145 |
+
"""
|
146 |
+
if self.density_map is None or self.processed_image is None:
|
147 |
+
raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.")
|
148 |
+
|
149 |
+
fig, ax = plt.subplots(dpi=200, frameon=False)
|
150 |
+
ax.imshow(self.processed_image)
|
151 |
+
ax.imshow(self.density_map, cmap="jet", alpha=alpha)
|
152 |
+
ax.axis("off")
|
153 |
+
|
154 |
+
if save:
|
155 |
+
if save_path is None:
|
156 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
157 |
+
save_path = f"crowd_density_{timestamp}.png"
|
158 |
+
|
159 |
+
# 여백 제거하고 저장
|
160 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
|
161 |
+
print(f"Image saved to: {save_path}")
|
162 |
+
|
163 |
+
fig.canvas.draw()
|
164 |
+
image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
165 |
+
image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
|
166 |
+
image_from_plot = image_from_plot[:,:,:3] # RGB로 변환
|
167 |
+
|
168 |
+
return fig , image_from_plot
|
169 |
+
|
170 |
+
def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97,
|
171 |
+
save: bool = False, save_path: Optional[str] = None):
|
172 |
+
"""
|
173 |
+
예측된 군중 위치를 점으로 표시하여 시각화합니다.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
dot_size (int): 점의 크기. 기본값 20
|
177 |
+
sigma (float): Gaussian 필터의 sigma 값. 기본값 1
|
178 |
+
percentile (float): 임계값으로 사용할 백분위수 (0-100). 기본값 97
|
179 |
+
save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False
|
180 |
+
save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장.
|
181 |
+
기본값 None
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
Tuple[matplotlib.backends.backend_agg.FigureCanvasBase, np.ndarray]:
|
185 |
+
- matplotlib figure의 canvas 객체
|
186 |
+
- RGB 형식의 시각화된 이미지 배열 (H, W, 3)
|
187 |
+
Raises:
|
188 |
+
ValueError: density_map 또는 processed_image가 None인 경우 (predict 메서드가 실행되지 않은 경우)
|
189 |
+
"""
|
190 |
+
if self.density_map is None or self.processed_image is None:
|
191 |
+
raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.")
|
192 |
+
|
193 |
+
adjusted_pred_count = int(round(self.count))
|
194 |
+
|
195 |
+
fig, ax = plt.subplots(dpi=200, frameon=False)
|
196 |
+
ax.imshow(self.processed_image)
|
197 |
+
|
198 |
+
filtered_density = gaussian_filter(self.density_map, sigma=sigma)
|
199 |
+
|
200 |
+
threshold = np.percentile(filtered_density, percentile)
|
201 |
+
candidate_pixels = np.column_stack(np.where(filtered_density >= threshold))
|
202 |
+
|
203 |
+
if len(candidate_pixels) > adjusted_pred_count:
|
204 |
+
kmeans = KMeans(n_clusters=adjusted_pred_count, random_state=42, n_init=10)
|
205 |
+
kmeans.fit(candidate_pixels)
|
206 |
+
head_positions = kmeans.cluster_centers_.astype(int)
|
207 |
+
else:
|
208 |
+
head_positions = candidate_pixels
|
209 |
+
|
210 |
+
y_coords, x_coords = head_positions[:, 0], head_positions[:, 1]
|
211 |
+
ax.scatter(x_coords, y_coords,
|
212 |
+
c='red',
|
213 |
+
s=dot_size,
|
214 |
+
alpha=1.0,
|
215 |
+
edgecolors='white',
|
216 |
+
linewidth=1)
|
217 |
+
|
218 |
+
ax.axis("off")
|
219 |
+
|
220 |
+
if save:
|
221 |
+
if save_path is None:
|
222 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
223 |
+
save_path = f"crowd_dots_{timestamp}.png"
|
224 |
+
|
225 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
|
226 |
+
print(f"Image saved to: {save_path}")
|
227 |
+
|
228 |
+
# Figure를 numpy 배열로 변환
|
229 |
+
fig.canvas.draw()
|
230 |
+
image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
231 |
+
image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
|
232 |
+
image_from_plot = image_from_plot[:,:,:3] # RGB로 변환
|
233 |
+
|
234 |
+
# plt.close(fig)
|
235 |
+
# return image_from_plot
|
236 |
+
return fig.canvas, image_from_plot
|
237 |
+
|
238 |
+
def _process_image(self, image: Union[str, np.ndarray]) -> torch.Tensor:
|
239 |
+
"""
|
240 |
+
이미지를 전처리합니다. 이미지 경로, 넘파이 배열, Streamlit UploadedFile 모두 처리 가능합니다.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
image: 입력 이미지. 다음 형식 중 하나여야 합니다:
|
244 |
+
- str: 이미지 파일 경로
|
245 |
+
- np.ndarray: (H, W, 3) 형태의 RGB 이미지
|
246 |
+
- UploadedFile: Streamlit의 업로드된 파일
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
torch.Tensor: 전처리된 이미지 텐서, shape (1, 3, H, W)
|
250 |
+
|
251 |
+
Raises:
|
252 |
+
ValueError: 지원하지 않는 이미지 형식이 입력된 경우
|
253 |
+
Exception: 이미지 파일을 열 수 없는 경우
|
254 |
+
"""
|
255 |
+
to_tensor = ToTensor()
|
256 |
+
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
257 |
+
|
258 |
+
# 원본 이미지 저장
|
259 |
+
self.original_image = image
|
260 |
+
|
261 |
+
# 입력 타입에 따른 처리
|
262 |
+
if isinstance(image, str):
|
263 |
+
# 파일 경로인 경우
|
264 |
+
with open(image, "rb") as f:
|
265 |
+
pil_image = Image.open(f).convert("RGB")
|
266 |
+
elif isinstance(image, np.ndarray):
|
267 |
+
# 넘파이 배열인 경우
|
268 |
+
if image.dtype == np.uint8:
|
269 |
+
pil_image = Image.fromarray(image)
|
270 |
+
else:
|
271 |
+
# float 타입인 경우 [0, 1] 범위로 가정하고 변환
|
272 |
+
pil_image = Image.fromarray((image * 255).astype(np.uint8))
|
273 |
+
else:
|
274 |
+
# Streamlit UploadedFile 또는 기타 파일 객체인 경우
|
275 |
+
try:
|
276 |
+
pil_image = Image.open(image).convert("RGB")
|
277 |
+
except Exception as e:
|
278 |
+
raise ValueError(f"지원하지 않는 이미지 형식입니다: {type(image)}") from e
|
279 |
+
|
280 |
+
# 텐서 변환 및 정규화
|
281 |
+
tensor_image = to_tensor(pil_image)
|
282 |
+
normalized_image = normalize(tensor_image)
|
283 |
+
batched_image = normalized_image.unsqueeze(0) # (1, 3, H, W)
|
284 |
+
batched_image = batched_image.to(self.device)
|
285 |
+
|
286 |
+
return batched_image
|
287 |
+
def _post_process_image(self, image):
|
288 |
+
"""이미지 후처리를 수행합니다."""
|
289 |
+
image = normalize(image, mean=(0., 0., 0.),
|
290 |
+
std=(1. / self.std[0], 1. / self.std[1], 1. / self.std[2]))
|
291 |
+
image = normalize(image, mean=(-self.mean[0], -self.mean[1], -self.mean[2]),
|
292 |
+
std=(1., 1., 1.))
|
293 |
+
processed_image = to_pil_image(image.squeeze(0))
|
294 |
+
return processed_image
|
295 |
+
|
296 |
+
@torch.no_grad()
|
297 |
+
def predict(self, image: torch.Tensor) -> Image.Image:
|
298 |
+
"""
|
299 |
+
모델 출력 이미지의 후처리를 수행합니다.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
image (torch.Tensor): 후처리할 이미지 텐서, shape (1, 3, H, W)
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
PIL.Image.Image: 후처리된 PIL 이미지
|
306 |
+
|
307 |
+
Note:
|
308 |
+
이미지 텐서에 대해 정규화를 역변환하고 PIL 이미지 형식으로 변환합니다.
|
309 |
+
self.mean과 self.std 값을 사용하여 원본 이미지의 스케일로 복원합니다.
|
310 |
+
"""
|
311 |
+
processed_image = self._process_image(image)
|
312 |
+
image_height, image_width = processed_image.shape[-2:]
|
313 |
+
processed_image = processed_image.to(self.device)
|
314 |
+
|
315 |
+
pred_density = sliding_window_predict(self.model, processed_image,
|
316 |
+
self.window_size, self.stride)
|
317 |
+
pred_count = pred_density.sum().item()
|
318 |
+
resized_pred_density = resize_density_map(pred_density,
|
319 |
+
(image_height, image_width)).cpu()
|
320 |
+
|
321 |
+
self.processed_image = self._post_process_image(processed_image)
|
322 |
+
self.density_map = resized_pred_density.squeeze().numpy()
|
323 |
+
self.count = pred_count
|
324 |
+
|
325 |
+
return pred_count
|
326 |
+
|
327 |
+
def crowd_count(self):
|
328 |
+
"""
|
329 |
+
가장 최근 예측의 군중 수를 반환합니다.
|
330 |
+
|
331 |
+
Returns:
|
332 |
+
float: 예측된 군중 수
|
333 |
+
None: 아직 예측이 수행되지 않은 경우
|
334 |
+
"""
|
335 |
+
return self.count
|
336 |
+
|
337 |
+
def get_density_map(self):
|
338 |
+
"""
|
339 |
+
가장 최근 예측의 밀도 맵을 반환합니다.
|
340 |
+
|
341 |
+
Returns:
|
342 |
+
numpy.ndarray: 밀도 맵
|
343 |
+
None: 아직 예측이 수행되지 않은 경우
|
344 |
+
"""
|
345 |
+
return self.density_map
|
346 |
+
|
custom/clip_ebc_onnx.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
from typing import Union, Tuple, Optional
|
7 |
+
from PIL import Image
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from torchvision.transforms import ToTensor, Normalize
|
10 |
+
from torchvision.transforms.functional import normalize, to_pil_image
|
11 |
+
import json
|
12 |
+
import datetime
|
13 |
+
from scipy.ndimage import gaussian_filter
|
14 |
+
from sklearn.cluster import KMeans
|
15 |
+
import assets
|
16 |
+
|
17 |
+
# 프로젝트 루트 디렉토리 설정
|
18 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
19 |
+
sys.path.append(project_root)
|
20 |
+
|
21 |
+
class ClipEBCOnnx:
|
22 |
+
"""
|
23 |
+
CLIP-EBC (Efficient Boundary Counting) ONNX 버전 이미지 처리 클래스입니다.
|
24 |
+
|
25 |
+
ONNX로 변환된 CLIP 모델을 사용하여 이미지를 처리하며, 슬라이딩 윈도우 예측 기능을 포함한
|
26 |
+
다양한 설정 옵션을 제공합니다.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self,
|
30 |
+
onnx_model_path="assets/CLIP_EBC_nwpu_rmse_onnx.onnx",
|
31 |
+
truncation=4,
|
32 |
+
reduction=8,
|
33 |
+
granularity="fine",
|
34 |
+
anchor_points="average",
|
35 |
+
input_size=224,
|
36 |
+
window_size=224,
|
37 |
+
stride=224,
|
38 |
+
dataset_name="qnrf",
|
39 |
+
mean=(0.485, 0.456, 0.406),
|
40 |
+
std=(0.229, 0.224, 0.225),
|
41 |
+
config_dir="configs"):
|
42 |
+
"""CLIPEBC ONNX 클래스를 설정 매개변수와 함께 초기화합니다."""
|
43 |
+
self.onnx_model_path = onnx_model_path
|
44 |
+
self.truncation = truncation
|
45 |
+
self.reduction = reduction
|
46 |
+
self.granularity = granularity
|
47 |
+
self.anchor_points_type = anchor_points
|
48 |
+
self.input_size = input_size
|
49 |
+
self.window_size = window_size
|
50 |
+
self.stride = stride
|
51 |
+
self.dataset_name = dataset_name
|
52 |
+
self.mean = mean
|
53 |
+
self.std = std
|
54 |
+
self.config_dir = config_dir
|
55 |
+
|
56 |
+
# 결과 저장용 변수 초기화
|
57 |
+
self.density_map = None
|
58 |
+
self.processed_image = None
|
59 |
+
self.count = None
|
60 |
+
self.original_image = None
|
61 |
+
|
62 |
+
# ONNX 추론 세션 설정
|
63 |
+
self.session_options = ort.SessionOptions()
|
64 |
+
self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
65 |
+
|
66 |
+
# 가능한 경우 GPU 사용
|
67 |
+
self.providers = []
|
68 |
+
if 'CUDAExecutionProvider' in ort.get_available_providers():
|
69 |
+
self.providers.append('CUDAExecutionProvider')
|
70 |
+
self.providers.append('CPUExecutionProvider')
|
71 |
+
|
72 |
+
# ONNX 런타임 세션 초기화
|
73 |
+
print(f"ONNX 모델 로드 중: {self.onnx_model_path}")
|
74 |
+
self.session = ort.InferenceSession(
|
75 |
+
self.onnx_model_path,
|
76 |
+
sess_options=self.session_options,
|
77 |
+
providers=self.providers
|
78 |
+
)
|
79 |
+
|
80 |
+
# 모델의 입력 및 출력 이름 가져오기
|
81 |
+
self.input_name = self.session.get_inputs()[0].name
|
82 |
+
self.output_name = self.session.get_outputs()[0].name
|
83 |
+
|
84 |
+
print(f"입력 이름: {self.input_name}, 형태: {self.session.get_inputs()[0].shape}")
|
85 |
+
print(f"출력 이름: {self.output_name}, 형태: {self.session.get_outputs()[0].shape}")
|
86 |
+
print(f"실행 제공자: {self.providers}")
|
87 |
+
|
88 |
+
def _process_image(self, image: Union[str, np.ndarray]) -> np.ndarray:
|
89 |
+
"""
|
90 |
+
이미지를 전처리합니다. 이미지 경로, 넘파이 배열, Streamlit UploadedFile 모두 처리 가능합니다.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
image: 입력 이미지. 다음 형식 중 하나여야 합니다:
|
94 |
+
- str: 이미지 파일 경로
|
95 |
+
- np.ndarray: (H, W, 3) 형태의 RGB 이미지
|
96 |
+
- UploadedFile: Streamlit의 업로드된 파일
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
np.ndarray: 전처리된 이미지 배열, shape (1, 3, H, W)
|
100 |
+
"""
|
101 |
+
to_tensor = ToTensor()
|
102 |
+
normalize = Normalize(mean=self.mean, std=self.std)
|
103 |
+
|
104 |
+
# 원본 이미지 저장
|
105 |
+
self.original_image = image
|
106 |
+
|
107 |
+
# 입력 타입에 따른 처리
|
108 |
+
if isinstance(image, str):
|
109 |
+
# 파일 경로인 경우
|
110 |
+
with open(image, "rb") as f:
|
111 |
+
pil_image = Image.open(f).convert("RGB")
|
112 |
+
elif isinstance(image, np.ndarray):
|
113 |
+
# 넘파이 배열인 경우
|
114 |
+
if image.dtype == np.uint8:
|
115 |
+
pil_image = Image.fromarray(image)
|
116 |
+
else:
|
117 |
+
# float 타입인 경우 [0, 1] 범위로 가정하고 변환
|
118 |
+
pil_image = Image.fromarray((image * 255).astype(np.uint8))
|
119 |
+
else:
|
120 |
+
# Streamlit UploadedFile 또는 기타 파일 객체인 경우
|
121 |
+
try:
|
122 |
+
pil_image = Image.open(image).convert("RGB")
|
123 |
+
except Exception as e:
|
124 |
+
raise ValueError(f"지원하지 않는 이미지 형식입니다: {type(image)}") from e
|
125 |
+
|
126 |
+
# 텐서 변환 및 정규화
|
127 |
+
tensor_image = to_tensor(pil_image)
|
128 |
+
normalized_image = normalize(tensor_image)
|
129 |
+
batched_image = normalized_image.unsqueeze(0) # (1, 3, H, W)
|
130 |
+
|
131 |
+
# numpy로 변환
|
132 |
+
numpy_image = batched_image.numpy()
|
133 |
+
|
134 |
+
return numpy_image
|
135 |
+
|
136 |
+
def _post_process_image(self, image_tensor):
|
137 |
+
"""이미지 텐서를 PIL 이미지로 변환합니다."""
|
138 |
+
# NumPy 배열을 PyTorch 텐서로 변환
|
139 |
+
if isinstance(image_tensor, np.ndarray):
|
140 |
+
image_tensor = torch.from_numpy(image_tensor)
|
141 |
+
|
142 |
+
# 정규화 역변환
|
143 |
+
image = normalize(
|
144 |
+
image_tensor,
|
145 |
+
mean=[0., 0., 0.],
|
146 |
+
std=[1./self.std[0], 1./self.std[1], 1./self.std[2]]
|
147 |
+
)
|
148 |
+
|
149 |
+
image = normalize(
|
150 |
+
image,
|
151 |
+
mean=[-self.mean[0], -self.mean[1], -self.mean[2]],
|
152 |
+
std=[1., 1., 1.]
|
153 |
+
)
|
154 |
+
|
155 |
+
# 배치 차원 제거 및 PIL 이미지로 변환
|
156 |
+
processed_image = to_pil_image(image.squeeze(0))
|
157 |
+
return processed_image
|
158 |
+
|
159 |
+
def sliding_window_predict(self, image: np.ndarray, window_size: Union[int, Tuple[int, int]],
|
160 |
+
stride: Union[int, Tuple[int, int]]) -> np.ndarray:
|
161 |
+
"""
|
162 |
+
슬라이딩 윈도우 방식으로 이미지 예측을 수행합니다. 겹치는 영역은 평균값을 사용합니다.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
image (np.ndarray): 형태가 (1, 3, H, W)인 이미지 배열
|
166 |
+
window_size (int or tuple): 윈도우 크기
|
167 |
+
stride (int or tuple): 윈도우 이동 간격
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
np.ndarray: 예측된 밀도 맵
|
171 |
+
"""
|
172 |
+
# 입력 검증
|
173 |
+
assert len(image.shape) == 4, f"이미지는 4차원 배열이어야 합니다. (1, C, H, W), 현재: {image.shape}"
|
174 |
+
|
175 |
+
# 윈도우 크기와 스트라이드 설정
|
176 |
+
window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
|
177 |
+
stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
|
178 |
+
window_size = tuple(window_size)
|
179 |
+
stride = tuple(stride)
|
180 |
+
|
181 |
+
# 검증
|
182 |
+
assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, \
|
183 |
+
f"윈도우 크기는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {window_size}"
|
184 |
+
assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, \
|
185 |
+
f"스트라이드는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {stride}"
|
186 |
+
assert stride[0] <= window_size[0] and stride[1] <= window_size[1], \
|
187 |
+
f"스트라이드는 윈도우 크기보다 작아야 합니다. 현재: {stride}와 {window_size}"
|
188 |
+
|
189 |
+
image_height, image_width = image.shape[-2:]
|
190 |
+
window_height, window_width = window_size
|
191 |
+
stride_height, stride_width = stride
|
192 |
+
|
193 |
+
# 슬라이딩 윈도우 수 계산
|
194 |
+
num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
|
195 |
+
num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)
|
196 |
+
|
197 |
+
# 윈도우 추출
|
198 |
+
windows = []
|
199 |
+
window_positions = []
|
200 |
+
for i in range(num_rows):
|
201 |
+
for j in range(num_cols):
|
202 |
+
x_start, y_start = i * stride_height, j * stride_width
|
203 |
+
x_end, y_end = x_start + window_height, y_start + window_width
|
204 |
+
|
205 |
+
# 이미지 경계 처리
|
206 |
+
if x_end > image_height:
|
207 |
+
x_start, x_end = image_height - window_height, image_height
|
208 |
+
if y_end > image_width:
|
209 |
+
y_start, y_end = image_width - window_width, image_width
|
210 |
+
|
211 |
+
window = image[:, :, x_start:x_end, y_start:y_end]
|
212 |
+
windows.append(window)
|
213 |
+
window_positions.append((x_start, y_start, x_end, y_end))
|
214 |
+
|
215 |
+
# 배치 단위로 추론
|
216 |
+
all_preds = []
|
217 |
+
max_batch_size = 8
|
218 |
+
|
219 |
+
for start_idx in range(0, len(windows), max_batch_size):
|
220 |
+
end_idx = min(start_idx + max_batch_size, len(windows))
|
221 |
+
batch_windows = np.vstack(windows[start_idx:end_idx]) # (batch_size, 3, h, w)
|
222 |
+
|
223 |
+
# ONNX 추론
|
224 |
+
ort_inputs = {self.input_name: batch_windows}
|
225 |
+
batch_preds = self.session.run([self.output_name], ort_inputs)[0]
|
226 |
+
|
227 |
+
# Debug 정보
|
228 |
+
# print(f"배치 입력 형태: {batch_windows.shape}, 배치 출력 형태: {batch_preds.shape}")
|
229 |
+
|
230 |
+
all_preds.extend([batch_preds[i:i+1] for i in range(batch_preds.shape[0])])
|
231 |
+
|
232 |
+
# 예측 결과를 numpy 배열로 변환
|
233 |
+
preds = np.concatenate(all_preds, axis=0)
|
234 |
+
|
235 |
+
# 출력 밀도 맵 조립
|
236 |
+
pred_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32)
|
237 |
+
count_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32)
|
238 |
+
|
239 |
+
idx = 0
|
240 |
+
for i in range(num_rows):
|
241 |
+
for j in range(num_cols):
|
242 |
+
x_start, y_start, x_end, y_end = window_positions[idx]
|
243 |
+
|
244 |
+
# 출력 영역 계산 (reduction 고려)
|
245 |
+
x_start_out = x_start // self.reduction
|
246 |
+
y_start_out = y_start // self.reduction
|
247 |
+
x_end_out = x_end // self.reduction
|
248 |
+
y_end_out = y_end // self.reduction
|
249 |
+
|
250 |
+
pred_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += preds[idx]
|
251 |
+
count_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += 1.
|
252 |
+
idx += 1
|
253 |
+
|
254 |
+
# 겹치는 영역 평균 계산
|
255 |
+
pred_map /= count_map
|
256 |
+
|
257 |
+
return pred_map
|
258 |
+
|
259 |
+
def resize_density_map(self, density_map: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
|
260 |
+
"""
|
261 |
+
밀도 맵의 크기를 조정합니다. 총합은 보존됩니다.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
density_map: 형태가 (C, H, W)인 밀도 맵
|
265 |
+
target_size: 목표 크기 (H', W')
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
np.ndarray: 크기가 조정된 밀도 맵
|
269 |
+
"""
|
270 |
+
from PIL import Image
|
271 |
+
import torch.nn.functional as F
|
272 |
+
import torch
|
273 |
+
|
274 |
+
# numpy를 torch로 변환
|
275 |
+
if isinstance(density_map, np.ndarray):
|
276 |
+
density_map = torch.from_numpy(density_map)
|
277 |
+
|
278 |
+
# 배치 차원 추가
|
279 |
+
if density_map.dim() == 3:
|
280 |
+
density_map = density_map.unsqueeze(0) # (1, C, H, W)
|
281 |
+
|
282 |
+
current_size = density_map.shape[2:]
|
283 |
+
|
284 |
+
if current_size[0] == target_size[0] and current_size[1] == target_size[1]:
|
285 |
+
return density_map.squeeze(0).numpy()
|
286 |
+
|
287 |
+
# 원본 밀도 맵의 총합 계산
|
288 |
+
original_sum = density_map.sum()
|
289 |
+
|
290 |
+
# 크기 조정 (쌍선형 보간)
|
291 |
+
resized_map = F.interpolate(
|
292 |
+
density_map,
|
293 |
+
size=target_size,
|
294 |
+
mode='bilinear',
|
295 |
+
align_corners=False
|
296 |
+
)
|
297 |
+
|
298 |
+
# 총합 보존을 위한 스케일링
|
299 |
+
if resized_map.sum() > 0: # 0으로 나누기 방지
|
300 |
+
resized_map = resized_map * (original_sum / resized_map.sum())
|
301 |
+
|
302 |
+
return resized_map.squeeze(0).numpy()
|
303 |
+
|
304 |
+
def predict(self, image: Union[str, np.ndarray]) -> float:
|
305 |
+
"""
|
306 |
+
이미지에서 군중 계수 예측을 수행합니다.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
image: 입력 이미지 (경로, 넘파이 배열, 또는 업로드된 파일)
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
float: 예측된 사람 수
|
313 |
+
"""
|
314 |
+
# 이미지 전처리
|
315 |
+
processed_image = self._process_image(image)
|
316 |
+
image_height, image_width = processed_image.shape[-2:]
|
317 |
+
|
318 |
+
# 슬라이딩 윈도우 예측
|
319 |
+
pred_density = self.sliding_window_predict(
|
320 |
+
processed_image,
|
321 |
+
self.window_size,
|
322 |
+
self.stride
|
323 |
+
)
|
324 |
+
|
325 |
+
# 예측 결과 저장
|
326 |
+
pred_count = pred_density.sum()
|
327 |
+
|
328 |
+
# 원본 이미지 크기로 밀도 맵 조정
|
329 |
+
resized_pred_density = self.resize_density_map(
|
330 |
+
pred_density,
|
331 |
+
(image_height, image_width)
|
332 |
+
)
|
333 |
+
|
334 |
+
# 결과 저장
|
335 |
+
self.processed_image = self._post_process_image(processed_image)
|
336 |
+
self.density_map = resized_pred_density.squeeze()
|
337 |
+
self.count = pred_count
|
338 |
+
|
339 |
+
return pred_count
|
340 |
+
|
341 |
+
def visualize_density_map(self, alpha: float = 0.5, save: bool = False,
|
342 |
+
save_path: Optional[str] = None):
|
343 |
+
"""
|
344 |
+
현재 저장된 예측 결과를 시각화합니다.
|
345 |
+
|
346 |
+
Args:
|
347 |
+
alpha (float): density map의 투명도 (0~1). 기본값 0.5
|
348 |
+
save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False
|
349 |
+
save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장.
|
350 |
+
기본값 None
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
Tuple[matplotlib.figure.Figure, np.ndarray]:
|
354 |
+
- density map이 오버레이된 matplotlib Figure 객체
|
355 |
+
- RGB 형식의 시각화된 이미지 배열 (H, W, 3)
|
356 |
+
"""
|
357 |
+
if self.density_map is None or self.processed_image is None:
|
358 |
+
raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.")
|
359 |
+
|
360 |
+
fig, ax = plt.subplots(dpi=200, frameon=False)
|
361 |
+
ax.imshow(self.processed_image)
|
362 |
+
ax.imshow(self.density_map, cmap="jet", alpha=alpha)
|
363 |
+
ax.axis("off")
|
364 |
+
plt.title(f"Count: {self.count:.1f}")
|
365 |
+
|
366 |
+
if save:
|
367 |
+
if save_path is None:
|
368 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
369 |
+
save_path = f"crowd_density_{timestamp}.png"
|
370 |
+
|
371 |
+
# 여백 제거하고 저장
|
372 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
|
373 |
+
print(f"이미지 저장 완료: {save_path}")
|
374 |
+
|
375 |
+
fig.canvas.draw()
|
376 |
+
image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
377 |
+
image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
|
378 |
+
image_from_plot = image_from_plot[:,:,:3] # RGB로 변환
|
379 |
+
|
380 |
+
return fig, image_from_plot
|
381 |
+
|
382 |
+
def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97,
|
383 |
+
save: bool = False, save_path: Optional[str] = None):
|
384 |
+
"""
|
385 |
+
예측된 군중 위치를 점으로 표시하여 시각화합니다.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
dot_size (int): 점의 크기. 기본값 20
|
389 |
+
sigma (float): Gaussian 필터의 sigma 값. 기본값 1
|
390 |
+
percentile (float): 임계값으로 사용할 백분위수 (0-100). 기본값 97
|
391 |
+
save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False
|
392 |
+
save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장.
|
393 |
+
기본값 None
|
394 |
+
|
395 |
+
Returns:
|
396 |
+
Tuple[matplotlib.backends.backend_agg.FigureCanvasBase, np.ndarray]:
|
397 |
+
- matplotlib figure의 canvas 객체
|
398 |
+
- RGB 형식의 시각화된 이미지 배열 (H, W, 3)
|
399 |
+
"""
|
400 |
+
if self.density_map is None or self.processed_image is None:
|
401 |
+
raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.")
|
402 |
+
|
403 |
+
adjusted_pred_count = int(round(self.count))
|
404 |
+
|
405 |
+
fig, ax = plt.subplots(dpi=200, frameon=False)
|
406 |
+
ax.imshow(self.processed_image)
|
407 |
+
|
408 |
+
filtered_density = gaussian_filter(self.density_map, sigma=sigma)
|
409 |
+
|
410 |
+
threshold = np.percentile(filtered_density, percentile)
|
411 |
+
candidate_pixels = np.column_stack(np.where(filtered_density >= threshold))
|
412 |
+
|
413 |
+
if len(candidate_pixels) > adjusted_pred_count:
|
414 |
+
kmeans = KMeans(n_clusters=adjusted_pred_count, random_state=42, n_init=10)
|
415 |
+
kmeans.fit(candidate_pixels)
|
416 |
+
head_positions = kmeans.cluster_centers_.astype(int)
|
417 |
+
else:
|
418 |
+
head_positions = candidate_pixels
|
419 |
+
|
420 |
+
y_coords, x_coords = head_positions[:, 0], head_positions[:, 1]
|
421 |
+
ax.scatter(x_coords, y_coords,
|
422 |
+
c='red',
|
423 |
+
s=dot_size,
|
424 |
+
alpha=1.0,
|
425 |
+
edgecolors='white',
|
426 |
+
linewidth=1)
|
427 |
+
|
428 |
+
ax.axis("off")
|
429 |
+
plt.title(f"Count: {self.count:.1f}")
|
430 |
+
|
431 |
+
if save:
|
432 |
+
if save_path is None:
|
433 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
434 |
+
save_path = f"crowd_dots_{timestamp}.png"
|
435 |
+
|
436 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
|
437 |
+
print(f"이미지 저장 완료: {save_path}")
|
438 |
+
|
439 |
+
# Figure를 numpy 배열로 변환
|
440 |
+
fig.canvas.draw()
|
441 |
+
image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
442 |
+
image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
|
443 |
+
image_from_plot = image_from_plot[:,:,:3] # RGB로 변환
|
444 |
+
|
445 |
+
return fig.canvas, image_from_plot
|
446 |
+
|
447 |
+
def crowd_count(self):
|
448 |
+
"""
|
449 |
+
가장 최근 예측의 군중 수를 반환합니다.
|
450 |
+
|
451 |
+
Returns:
|
452 |
+
float: 예측된 군중 수
|
453 |
+
None: 아직 예측이 수행되지 않은 경우
|
454 |
+
"""
|
455 |
+
return self.count
|
456 |
+
|
457 |
+
def get_density_map(self):
|
458 |
+
"""
|
459 |
+
가장 최근 예측의 밀도 맵을 반환합니다.
|
460 |
+
|
461 |
+
Returns:
|
462 |
+
numpy.ndarray: 밀도 맵
|
463 |
+
None: 아직 예측이 수행되지 않은 경우
|
464 |
+
"""
|
465 |
+
return self.density_map
|
custom/clip_ebc_tensorrt.py
ADDED
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import tensorrt as trt
|
6 |
+
from typing import Union, Tuple, Optional
|
7 |
+
from PIL import Image
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from torchvision.transforms import ToTensor, Normalize
|
10 |
+
from torchvision.transforms.functional import normalize, to_pil_image
|
11 |
+
import json
|
12 |
+
import datetime
|
13 |
+
from scipy.ndimage import gaussian_filter
|
14 |
+
from sklearn.cluster import KMeans
|
15 |
+
import assets
|
16 |
+
|
17 |
+
# 프로젝트 루트 디렉토리 설정
|
18 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
19 |
+
sys.path.append(project_root)
|
20 |
+
|
21 |
+
class ClipEBCTensorRT:
|
22 |
+
"""
|
23 |
+
CLIP-EBC (Efficient Boundary Counting) TensorRT 버전 이미지 처리 클래스입니다.
|
24 |
+
|
25 |
+
TensorRT로 변환된 CLIP 모델을 사용하여 이미지를 처리하며, 슬라이딩 윈도우 예측 기능을 포함한
|
26 |
+
다양한 설정 옵션을 제공합니다.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self,
|
30 |
+
engine_path="assets/CLIP_EBC_nwpu_rmse_tensorrt.trt",
|
31 |
+
truncation=4,
|
32 |
+
reduction=8,
|
33 |
+
granularity="fine",
|
34 |
+
anchor_points="average",
|
35 |
+
input_size=224,
|
36 |
+
window_size=224,
|
37 |
+
stride=224,
|
38 |
+
dataset_name="qnrf",
|
39 |
+
mean=(0.485, 0.456, 0.406),
|
40 |
+
std=(0.229, 0.224, 0.225),
|
41 |
+
config_dir ="configs"):
|
42 |
+
"""CLIPEBC TensorRT 클래스를 설정 매개변수와 함께 초기화합니다."""
|
43 |
+
self.engine_path = engine_path
|
44 |
+
self.truncation = truncation
|
45 |
+
self.reduction = reduction
|
46 |
+
self.granularity = granularity
|
47 |
+
self.anchor_points_type = anchor_points
|
48 |
+
self.input_size = input_size
|
49 |
+
self.window_size = window_size
|
50 |
+
self.stride = stride
|
51 |
+
self.dataset_name = dataset_name
|
52 |
+
self.mean = mean
|
53 |
+
self.std = std
|
54 |
+
self.config_dir = config_dir
|
55 |
+
|
56 |
+
# 결과 저장용 변수 초기화
|
57 |
+
self.density_map = None
|
58 |
+
self.processed_image = None
|
59 |
+
self.count = None
|
60 |
+
self.original_image = None
|
61 |
+
|
62 |
+
# TensorRT 엔진 로드
|
63 |
+
print(f"TensorRT 엔진 로드 중: {self.engine_path}")
|
64 |
+
self._load_engine()
|
65 |
+
|
66 |
+
# 입력 및 출력 이름 설정
|
67 |
+
self.input_name = "input"
|
68 |
+
self.output_name = "output"
|
69 |
+
|
70 |
+
print(f"TensorRT 엔진 초기화 완료")
|
71 |
+
|
72 |
+
def _load_engine(self):
|
73 |
+
"""TensorRT 엔진을 로드합니다."""
|
74 |
+
# TensorRT 로거 생성
|
75 |
+
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
76 |
+
|
77 |
+
# 런타임 생성
|
78 |
+
self.runtime = trt.Runtime(TRT_LOGGER)
|
79 |
+
|
80 |
+
# 엔진 파일 로드
|
81 |
+
with open(self.engine_path, 'rb') as f:
|
82 |
+
engine_data = f.read()
|
83 |
+
|
84 |
+
# 직렬화된 엔진에서 엔진 생성
|
85 |
+
self.engine = self.runtime.deserialize_cuda_engine(engine_data)
|
86 |
+
|
87 |
+
# 실행 컨텍스트 생성
|
88 |
+
self.context = self.engine.create_execution_context()
|
89 |
+
|
90 |
+
# TensorRT 10.x에서는 input_binding/output_binding 대신 네트워크 구조를 확인
|
91 |
+
# 입력과 출력을 가져오는 방법이 변경됨
|
92 |
+
self.num_io_tensors = self.engine.num_io_tensors
|
93 |
+
|
94 |
+
# 입력과 출력 텐서 이름 찾기
|
95 |
+
self.input_tensor_names = []
|
96 |
+
self.output_tensor_names = []
|
97 |
+
|
98 |
+
print(f"TensorRT 엔진에서 {self.num_io_tensors}개의 IO 텐서를 찾았습니다")
|
99 |
+
|
100 |
+
for i in range(self.num_io_tensors):
|
101 |
+
name = self.engine.get_tensor_name(i)
|
102 |
+
is_input = self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
|
103 |
+
|
104 |
+
if is_input:
|
105 |
+
self.input_tensor_names.append(name)
|
106 |
+
else:
|
107 |
+
self.output_tensor_names.append(name)
|
108 |
+
|
109 |
+
# 입력과 출력 이름 설정
|
110 |
+
if not self.input_tensor_names:
|
111 |
+
raise ValueError("엔진에서 입력 텐서를 찾을 수 없습니다.")
|
112 |
+
if not self.output_tensor_names:
|
113 |
+
raise ValueError("엔진에서 출력 텐서를 찾을 수 없습니다.")
|
114 |
+
|
115 |
+
# 기본 입력 및 출력 이름 설정
|
116 |
+
self.input_name = self.input_tensor_names[0]
|
117 |
+
self.output_name = self.output_tensor_names[0]
|
118 |
+
|
119 |
+
# 입출력 형태 추출
|
120 |
+
self.input_shape = self.engine.get_tensor_shape(self.input_name)
|
121 |
+
self.output_shape = self.engine.get_tensor_shape(self.output_name)
|
122 |
+
|
123 |
+
print(f"입력 이름: {self.input_name}, 형태: {self.input_shape}")
|
124 |
+
print(f"출력 이름: {self.output_name}, 형태: {self.output_shape}")
|
125 |
+
|
126 |
+
def _process_image(self, image: Union[str, np.ndarray]) -> np.ndarray:
|
127 |
+
"""
|
128 |
+
이미지를 전처리합니다. 이미지 경로, 넘파이 배열, Streamlit UploadedFile 모두 처리 가능합니다.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
image: 입력 이미지. 다음 형식 �� 하나여야 합니다:
|
132 |
+
- str: 이미지 파일 경로
|
133 |
+
- np.ndarray: (H, W, 3) 형태의 RGB 이미지
|
134 |
+
- UploadedFile: Streamlit의 업로드된 파일
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
np.ndarray: 전처리된 이미지 배열, shape (1, 3, H, W)
|
138 |
+
"""
|
139 |
+
to_tensor = ToTensor()
|
140 |
+
normalize = Normalize(mean=self.mean, std=self.std)
|
141 |
+
|
142 |
+
# 원본 이미지 저장
|
143 |
+
self.original_image = image
|
144 |
+
|
145 |
+
# 입력 타입에 따른 처리
|
146 |
+
if isinstance(image, str):
|
147 |
+
# 파일 경로인 경우
|
148 |
+
with open(image, "rb") as f:
|
149 |
+
pil_image = Image.open(f).convert("RGB")
|
150 |
+
elif isinstance(image, np.ndarray):
|
151 |
+
# 넘파이 배열인 경우
|
152 |
+
if image.dtype == np.uint8:
|
153 |
+
pil_image = Image.fromarray(image)
|
154 |
+
else:
|
155 |
+
# float 타입인 경우 [0, 1] 범위로 가정하고 변환
|
156 |
+
pil_image = Image.fromarray((image * 255).astype(np.uint8))
|
157 |
+
else:
|
158 |
+
# Streamlit UploadedFile 또는 기타 파일 객체인 경우
|
159 |
+
try:
|
160 |
+
pil_image = Image.open(image).convert("RGB")
|
161 |
+
except Exception as e:
|
162 |
+
raise ValueError(f"지원하지 않는 이미지 형식입니다: {type(image)}") from e
|
163 |
+
|
164 |
+
# 텐서 변환 및 정규화
|
165 |
+
tensor_image = to_tensor(pil_image)
|
166 |
+
normalized_image = normalize(tensor_image)
|
167 |
+
batched_image = normalized_image.unsqueeze(0) # (1, 3, H, W)
|
168 |
+
|
169 |
+
# numpy로 변환
|
170 |
+
numpy_image = batched_image.numpy()
|
171 |
+
|
172 |
+
return numpy_image
|
173 |
+
|
174 |
+
def _post_process_image(self, image_tensor):
|
175 |
+
"""이미지 텐서를 PIL 이미지로 변환합니다."""
|
176 |
+
# NumPy 배열을 PyTorch 텐서로 변환
|
177 |
+
if isinstance(image_tensor, np.ndarray):
|
178 |
+
image_tensor = torch.from_numpy(image_tensor)
|
179 |
+
|
180 |
+
# 정규화 역변환
|
181 |
+
image = normalize(
|
182 |
+
image_tensor,
|
183 |
+
mean=[0., 0., 0.],
|
184 |
+
std=[1./self.std[0], 1./self.std[1], 1./self.std[2]]
|
185 |
+
)
|
186 |
+
|
187 |
+
image = normalize(
|
188 |
+
image,
|
189 |
+
mean=[-self.mean[0], -self.mean[1], -self.mean[2]],
|
190 |
+
std=[1., 1., 1.]
|
191 |
+
)
|
192 |
+
|
193 |
+
# 배치 차원 제거 및 PIL 이미지로 변환
|
194 |
+
processed_image = to_pil_image(image.squeeze(0))
|
195 |
+
return processed_image
|
196 |
+
def _infer_batch(self, batch_input):
|
197 |
+
"""
|
198 |
+
TensorRT 엔진을 사용하여 배치 추론을 수행합니다. (수정 버전)
|
199 |
+
"""
|
200 |
+
import pycuda.driver as cuda
|
201 |
+
import pycuda.autoinit
|
202 |
+
import numpy as np
|
203 |
+
|
204 |
+
batch_size = batch_input.shape[0]
|
205 |
+
|
206 |
+
# 입력의 형태와 데이터 타입 확인
|
207 |
+
input_shape = (batch_size, 3, self.input_size, self.input_size)
|
208 |
+
print(f"입력 배치 형태: {batch_input.shape}, 데이터 타입: {batch_input.dtype}")
|
209 |
+
|
210 |
+
# 입력 형태 검증
|
211 |
+
if batch_input.shape != input_shape:
|
212 |
+
print(f"경고: 입력 형태 불일치. 예상: {input_shape}, 실제: {batch_input.shape}")
|
213 |
+
# 필요시 형태 수정
|
214 |
+
batch_input = np.resize(batch_input, input_shape)
|
215 |
+
|
216 |
+
# 데이터 타입 검증
|
217 |
+
if batch_input.dtype != np.float32:
|
218 |
+
print(f"경고: 입력 데이터 타입 불일치. float32로 변환합니다.")
|
219 |
+
batch_input = batch_input.astype(np.float32)
|
220 |
+
|
221 |
+
# 동적 배치 크기 설정
|
222 |
+
self.context.set_input_shape(self.input_name, input_shape)
|
223 |
+
|
224 |
+
# 출력 형태 가져오기
|
225 |
+
output_shape = self.context.get_tensor_shape(self.output_name)
|
226 |
+
output_shape = tuple(output_shape) # 튜플로 변환하여 안전성 보장
|
227 |
+
print(f"출력 형태: {output_shape}")
|
228 |
+
|
229 |
+
# -1 값을 실제 배치 크기로 대체
|
230 |
+
if output_shape[0] == -1:
|
231 |
+
output_shape = (batch_size,) + output_shape[1:]
|
232 |
+
|
233 |
+
# 출력 버퍼 준비
|
234 |
+
output = np.empty(output_shape, dtype=np.float32)
|
235 |
+
|
236 |
+
# 호스트 메모리 준비 (페이지 잠금 메모리 사용)
|
237 |
+
h_input = cuda.pagelocked_empty(batch_input.shape, dtype=np.float32)
|
238 |
+
h_output = cuda.pagelocked_empty(output_shape, dtype=np.float32)
|
239 |
+
|
240 |
+
# 입력 데이터 복사
|
241 |
+
np.copyto(h_input, batch_input)
|
242 |
+
|
243 |
+
# 디바이스 메모리 할당
|
244 |
+
d_input = cuda.mem_alloc(h_input.nbytes)
|
245 |
+
d_output = cuda.mem_alloc(h_output.nbytes)
|
246 |
+
|
247 |
+
# CUDA 스트림 생성
|
248 |
+
stream = cuda.Stream()
|
249 |
+
|
250 |
+
try:
|
251 |
+
# 메모리 복사 (호스트 -> 디바이스)
|
252 |
+
cuda.memcpy_htod_async(d_input, h_input, stream)
|
253 |
+
|
254 |
+
# 텐서 주소 설정
|
255 |
+
self.context.set_tensor_address(self.input_name, int(d_input))
|
256 |
+
self.context.set_tensor_address(self.output_name, int(d_output))
|
257 |
+
|
258 |
+
# 디버깅 정보 (메모리 주소)
|
259 |
+
print(f"입력 메모리 주소: {int(d_input)}, 출력 메모리 주소: {int(d_output)}")
|
260 |
+
|
261 |
+
# 실행
|
262 |
+
success = self.context.execute_async_v3(stream_handle=stream.handle)
|
263 |
+
if not success:
|
264 |
+
print("TensorRT 실행 실패")
|
265 |
+
return None
|
266 |
+
|
267 |
+
# 메모리 복사 (디바이스 -> 호스트)
|
268 |
+
cuda.memcpy_dtoh_async(h_output, d_output, stream)
|
269 |
+
|
270 |
+
# 스트림 동기화
|
271 |
+
stream.synchronize()
|
272 |
+
|
273 |
+
# 출력 데이터 복사
|
274 |
+
np.copyto(output, h_output)
|
275 |
+
|
276 |
+
return output
|
277 |
+
|
278 |
+
except Exception as e:
|
279 |
+
print(f"TensorRT 추론 중 오류 발생: {str(e)}")
|
280 |
+
import traceback
|
281 |
+
traceback.print_exc()
|
282 |
+
return None
|
283 |
+
|
284 |
+
finally:
|
285 |
+
# 메모리 해제
|
286 |
+
del stream
|
287 |
+
if 'd_input' in locals():
|
288 |
+
d_input.free()
|
289 |
+
if 'd_output' in locals():
|
290 |
+
d_output.free()
|
291 |
+
|
292 |
+
def sliding_window_predict(self, image: np.ndarray, window_size: Union[int, Tuple[int, int]],
|
293 |
+
stride: Union[int, Tuple[int, int]]) -> np.ndarray:
|
294 |
+
"""
|
295 |
+
슬라이딩 윈도우 방식으로 이미지 예측을 수행합니다. 겹치는 영역은 평균값을 사용합니다.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
image (np.ndarray): 형태가 (1, 3, H, W)인 이미지 배열
|
299 |
+
window_size (int or tuple): 윈도우 크기
|
300 |
+
stride (int or tuple): 윈도우 이동 간격
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
np.ndarray: 예측된 밀도 맵
|
304 |
+
"""
|
305 |
+
# CUDA 초기화 (처음 사용할 때만)
|
306 |
+
global cuda
|
307 |
+
if 'cuda' not in globals():
|
308 |
+
import pycuda.driver as cuda
|
309 |
+
cuda.init()
|
310 |
+
|
311 |
+
# 입력 검증
|
312 |
+
assert len(image.shape) == 4, f"이미지는 4차원 배열이어야 합니다. (1, C, H, W), 현재: {image.shape}"
|
313 |
+
|
314 |
+
# 윈도우 크기와 스트라이드 설정
|
315 |
+
window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
|
316 |
+
stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
|
317 |
+
window_size = tuple(window_size)
|
318 |
+
stride = tuple(stride)
|
319 |
+
|
320 |
+
# 검증
|
321 |
+
assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, \
|
322 |
+
f"윈도우 크기는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {window_size}"
|
323 |
+
assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, \
|
324 |
+
f"스트라이드는 양수 정수 튜플 (h, w)이어야 합니다. 현재: {stride}"
|
325 |
+
assert stride[0] <= window_size[0] and stride[1] <= window_size[1], \
|
326 |
+
f"스트라이드는 윈도우 크기보다 작아야 합니다. 현재: {stride}와 {window_size}"
|
327 |
+
|
328 |
+
image_height, image_width = image.shape[-2:]
|
329 |
+
window_height, window_width = window_size
|
330 |
+
stride_height, stride_width = stride
|
331 |
+
|
332 |
+
# 슬라이딩 윈도우 수 계산
|
333 |
+
num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
|
334 |
+
num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)
|
335 |
+
|
336 |
+
# 윈도우 추출
|
337 |
+
windows = []
|
338 |
+
window_positions = []
|
339 |
+
for i in range(num_rows):
|
340 |
+
for j in range(num_cols):
|
341 |
+
x_start, y_start = i * stride_height, j * stride_width
|
342 |
+
x_end, y_end = x_start + window_height, y_start + window_width
|
343 |
+
|
344 |
+
# 이미지 경계 처리
|
345 |
+
if x_end > image_height:
|
346 |
+
x_start, x_end = image_height - window_height, image_height
|
347 |
+
if y_end > image_width:
|
348 |
+
y_start, y_end = image_width - window_width, image_width
|
349 |
+
|
350 |
+
window = image[:, :, x_start:x_end, y_start:y_end]
|
351 |
+
windows.append(window)
|
352 |
+
window_positions.append((x_start, y_start, x_end, y_end))
|
353 |
+
|
354 |
+
# 배치 단위로 추론
|
355 |
+
all_preds = []
|
356 |
+
max_batch_size = 8
|
357 |
+
|
358 |
+
for start_idx in range(0, len(windows), max_batch_size):
|
359 |
+
end_idx = min(start_idx + max_batch_size, len(windows))
|
360 |
+
batch_windows = np.vstack(windows[start_idx:end_idx]) # (batch_size, 3, h, w)
|
361 |
+
|
362 |
+
# TensorRT 추론
|
363 |
+
batch_preds = self._infer_batch(batch_windows)
|
364 |
+
|
365 |
+
# Debug 정보
|
366 |
+
# print(f"배치 입력 형태: {batch_windows.shape}, 배치 출력 형태: {batch_preds.shape}")
|
367 |
+
|
368 |
+
all_preds.extend([batch_preds[i:i+1] for i in range(batch_preds.shape[0])])
|
369 |
+
|
370 |
+
# 예측 결과를 numpy 배열로 변환
|
371 |
+
preds = np.concatenate(all_preds, axis=0)
|
372 |
+
|
373 |
+
# 출력 밀도 맵 조립
|
374 |
+
pred_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32)
|
375 |
+
count_map = np.zeros((preds.shape[1], image_height // self.reduction, image_width // self.reduction), dtype=np.float32)
|
376 |
+
|
377 |
+
idx = 0
|
378 |
+
for i in range(num_rows):
|
379 |
+
for j in range(num_cols):
|
380 |
+
x_start, y_start, x_end, y_end = window_positions[idx]
|
381 |
+
|
382 |
+
# 출력 영역 계산 (reduction 고려)
|
383 |
+
x_start_out = x_start // self.reduction
|
384 |
+
y_start_out = y_start // self.reduction
|
385 |
+
x_end_out = x_end // self.reduction
|
386 |
+
y_end_out = y_end // self.reduction
|
387 |
+
|
388 |
+
pred_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += preds[idx]
|
389 |
+
count_map[:, x_start_out:x_end_out, y_start_out:y_end_out] += 1.
|
390 |
+
idx += 1
|
391 |
+
|
392 |
+
# 겹치는 영역 평균 계산
|
393 |
+
pred_map /= count_map
|
394 |
+
|
395 |
+
return pred_map
|
396 |
+
|
397 |
+
def resize_density_map(self, density_map: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
|
398 |
+
"""
|
399 |
+
밀도 맵의 크기를 조정합니다. 총합은 보존됩니다.
|
400 |
+
|
401 |
+
Args:
|
402 |
+
density_map: 형태가 (C, H, W)인 밀도 맵
|
403 |
+
target_size: 목표 크기 (H', W')
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
np.ndarray: 크기가 조정된 밀도 맵
|
407 |
+
"""
|
408 |
+
from PIL import Image
|
409 |
+
import torch.nn.functional as F
|
410 |
+
import torch
|
411 |
+
|
412 |
+
# numpy를 torch로 변환
|
413 |
+
if isinstance(density_map, np.ndarray):
|
414 |
+
density_map = torch.from_numpy(density_map)
|
415 |
+
|
416 |
+
# 배치 차원 추가
|
417 |
+
if density_map.dim() == 3:
|
418 |
+
density_map = density_map.unsqueeze(0) # (1, C, H, W)
|
419 |
+
|
420 |
+
current_size = density_map.shape[2:]
|
421 |
+
|
422 |
+
if current_size[0] == target_size[0] and current_size[1] == target_size[1]:
|
423 |
+
return density_map.squeeze(0).numpy()
|
424 |
+
|
425 |
+
# 원본 밀도 맵의 총합 계산
|
426 |
+
original_sum = density_map.sum()
|
427 |
+
|
428 |
+
# 크기 조정 (쌍선형 보간)
|
429 |
+
resized_map = F.interpolate(
|
430 |
+
density_map,
|
431 |
+
size=target_size,
|
432 |
+
mode='bilinear',
|
433 |
+
align_corners=False
|
434 |
+
)
|
435 |
+
|
436 |
+
# 총합 보존을 위한 스케일링
|
437 |
+
if resized_map.sum() > 0: # 0으로 나누기 방지
|
438 |
+
resized_map = resized_map * (original_sum / resized_map.sum())
|
439 |
+
|
440 |
+
return resized_map.squeeze(0).numpy()
|
441 |
+
|
442 |
+
def predict(self, image: Union[str, np.ndarray]) -> float:
|
443 |
+
"""
|
444 |
+
이미지에서 군중 계수 예측을 수행합니다.
|
445 |
+
|
446 |
+
Args:
|
447 |
+
image: 입력 이미지 (경로, 넘파이 배열, 또는 업로드된 파일)
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
float: 예측된 사람 수
|
451 |
+
"""
|
452 |
+
# 이미지 전처리
|
453 |
+
processed_image = self._process_image(image)
|
454 |
+
image_height, image_width = processed_image.shape[-2:]
|
455 |
+
|
456 |
+
# 슬라이딩 윈도우 예측
|
457 |
+
pred_density = self.sliding_window_predict(
|
458 |
+
processed_image,
|
459 |
+
self.window_size,
|
460 |
+
self.stride
|
461 |
+
)
|
462 |
+
|
463 |
+
# 예측 결과 저장
|
464 |
+
pred_count = pred_density.sum()
|
465 |
+
|
466 |
+
# 원본 이미지 크기로 밀도 맵 조정
|
467 |
+
resized_pred_density = self.resize_density_map(
|
468 |
+
pred_density,
|
469 |
+
(image_height, image_width)
|
470 |
+
)
|
471 |
+
|
472 |
+
# 결과 저장
|
473 |
+
self.processed_image = self._post_process_image(processed_image)
|
474 |
+
self.density_map = resized_pred_density.squeeze()
|
475 |
+
self.count = pred_count
|
476 |
+
|
477 |
+
return pred_count
|
478 |
+
|
479 |
+
def visualize_density_map(self, alpha: float = 0.5, save: bool = False,
|
480 |
+
save_path: Optional[str] = None):
|
481 |
+
"""
|
482 |
+
현재 저장된 예측 결과를 시각화합니다.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
alpha (float): density map의 투명도 (0~1). 기본값 0.5
|
486 |
+
save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False
|
487 |
+
save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장.
|
488 |
+
기본값 None
|
489 |
+
|
490 |
+
Returns:
|
491 |
+
Tuple[matplotlib.figure.Figure, np.ndarray]:
|
492 |
+
- density map이 오버레이된 matplotlib Figure 객체
|
493 |
+
- RGB 형식의 시각화된 이미지 배열 (H, W, 3)
|
494 |
+
"""
|
495 |
+
if self.density_map is None or self.processed_image is None:
|
496 |
+
raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.")
|
497 |
+
|
498 |
+
fig, ax = plt.subplots(dpi=200, frameon=False)
|
499 |
+
ax.imshow(self.processed_image)
|
500 |
+
ax.imshow(self.density_map, cmap="jet", alpha=alpha)
|
501 |
+
ax.axis("off")
|
502 |
+
plt.title(f"Count: {self.count:.1f}")
|
503 |
+
|
504 |
+
if save:
|
505 |
+
if save_path is None:
|
506 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
507 |
+
save_path = f"crowd_density_{timestamp}.png"
|
508 |
+
|
509 |
+
# 여백 제거하고 저장
|
510 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
|
511 |
+
print(f"이미지 저장 완료: {save_path}")
|
512 |
+
|
513 |
+
fig.canvas.draw()
|
514 |
+
image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
515 |
+
image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
|
516 |
+
image_from_plot = image_from_plot[:,:,:3] # RGB로 변환
|
517 |
+
|
518 |
+
return fig, image_from_plot
|
519 |
+
|
520 |
+
def visualize_dots(self, dot_size: int = 20, sigma: float = 1, percentile: float = 97,
|
521 |
+
save: bool = False, save_path: Optional[str] = None):
|
522 |
+
"""
|
523 |
+
예측된 군중 위치를 점으로 표시하여 시각화합니다.
|
524 |
+
|
525 |
+
Args:
|
526 |
+
dot_size (int): 점의 크기. 기본값 20
|
527 |
+
sigma (float): Gaussian 필터의 sigma 값. 기본값 1
|
528 |
+
percentile (float): 임계값으로 사용할 백분위수 (0-100). 기본값 97
|
529 |
+
save (bool): 시각화 결과를 이미지로 저장할지 여부. 기본값 False
|
530 |
+
save_path (str, optional): 저장할 경로. None일 경우 현재 디렉토리에 자동 생성된 이름으로 저장.
|
531 |
+
기본값 None
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
Tuple[matplotlib.backends.backend_agg.FigureCanvasBase, np.ndarray]:
|
535 |
+
- matplotlib figure의 canvas 객체
|
536 |
+
- RGB 형식의 시각화된 이미지 배열 (H, W, 3)
|
537 |
+
"""
|
538 |
+
if self.density_map is None or self.processed_image is None:
|
539 |
+
raise ValueError("먼저 predict 메서드를 실행하여 예측을 수행해야 합니다.")
|
540 |
+
|
541 |
+
adjusted_pred_count = int(round(self.count))
|
542 |
+
|
543 |
+
fig, ax = plt.subplots(dpi=200, frameon=False)
|
544 |
+
ax.imshow(self.processed_image)
|
545 |
+
|
546 |
+
filtered_density = gaussian_filter(self.density_map, sigma=sigma)
|
547 |
+
|
548 |
+
threshold = np.percentile(filtered_density, percentile)
|
549 |
+
candidate_pixels = np.column_stack(np.where(filtered_density >= threshold))
|
550 |
+
|
551 |
+
if len(candidate_pixels) > adjusted_pred_count:
|
552 |
+
kmeans = KMeans(n_clusters=adjusted_pred_count, random_state=42, n_init=10)
|
553 |
+
kmeans.fit(candidate_pixels)
|
554 |
+
head_positions = kmeans.cluster_centers_.astype(int)
|
555 |
+
else:
|
556 |
+
head_positions = candidate_pixels
|
557 |
+
|
558 |
+
y_coords, x_coords = head_positions[:, 0], head_positions[:, 1]
|
559 |
+
ax.scatter(x_coords, y_coords,
|
560 |
+
c='red',
|
561 |
+
s=dot_size,
|
562 |
+
alpha=1.0,
|
563 |
+
edgecolors='white',
|
564 |
+
linewidth=1)
|
565 |
+
|
566 |
+
ax.axis("off")
|
567 |
+
plt.title(f"Count: {self.count:.1f}")
|
568 |
+
|
569 |
+
if save:
|
570 |
+
if save_path is None:
|
571 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
572 |
+
save_path = f"crowd_dots_{timestamp}.png"
|
573 |
+
|
574 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
|
575 |
+
print(f"이미지 저장 완료: {save_path}")
|
576 |
+
|
577 |
+
# Figure를 numpy 배열로 변환
|
578 |
+
fig.canvas.draw()
|
579 |
+
image_from_plot = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
580 |
+
image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
|
581 |
+
image_from_plot = image_from_plot[:,:,:3] # RGB로 변환
|
582 |
+
|
583 |
+
return fig.canvas, image_from_plot
|
584 |
+
|
585 |
+
def crowd_count(self):
|
586 |
+
"""
|
587 |
+
가장 최근 예측의 군중 수를 반환합니다.
|
588 |
+
|
589 |
+
Returns:
|
590 |
+
float: 예측된 군중 수
|
591 |
+
None: 아직 예측이 수행되지 않은 경우
|
592 |
+
"""
|
593 |
+
return self.count
|
594 |
+
|
595 |
+
def get_density_map(self):
|
596 |
+
"""
|
597 |
+
가장 최근 예측의 밀도 맵을 반환합니다.
|
598 |
+
|
599 |
+
Returns:
|
600 |
+
numpy.ndarray: 밀도 맵
|
601 |
+
None: 아직 예측이 수행되지 않은 경우
|
602 |
+
"""
|
603 |
+
return self.density_map
|
custom/init_get_model.py
ADDED
File without changes
|
custom/json2seg.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
def get_segmentation_by_id(target_id, json_file="/home/jungseoik/data/PR/CLIP-EBC/assets/seg.json" ):
|
4 |
+
with open(json_file, "r", encoding="utf-8") as f:
|
5 |
+
data = json.load(f)
|
6 |
+
|
7 |
+
# annotations 리스트 가져오기
|
8 |
+
annotations = data.get("annotations", [])
|
9 |
+
|
10 |
+
# annotations 순회하면서 id가 target_id인 항목 찾기
|
11 |
+
for ann in annotations:
|
12 |
+
if ann.get("id") == target_id:
|
13 |
+
return ann.get("segmentation", None)
|
14 |
+
|
15 |
+
# 해당 id가 없으면 None 반환
|
16 |
+
return None
|
custom/mock_gen.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
def create_mock_data_heatmap():
|
5 |
+
# 기본 구조 생성
|
6 |
+
sections = [f'구역 {i}' for i in range(1, 7)]
|
7 |
+
months = list(range(1, 13))
|
8 |
+
years = list(range(2020, 2025))
|
9 |
+
|
10 |
+
# 데이터프레임용 리스트 생성
|
11 |
+
data = []
|
12 |
+
for year in years:
|
13 |
+
for section in sections:
|
14 |
+
for month in months:
|
15 |
+
data.append({
|
16 |
+
'section': section,
|
17 |
+
'month': month,
|
18 |
+
'year': year,
|
19 |
+
'crowd_count': np.random.randint(30000, 500000)
|
20 |
+
})
|
21 |
+
|
22 |
+
# DataFrame 생성
|
23 |
+
df = pd.DataFrame(data)
|
24 |
+
return df
|
25 |
+
|
26 |
+
def create_mock_data_table():
|
27 |
+
mock_data = {
|
28 |
+
'section': [f'구역 {i}' for i in range(1, 7)],
|
29 |
+
'count': np.random.randint(10000, 300000, 6)
|
30 |
+
}
|
31 |
+
|
32 |
+
df = pd.DataFrame(mock_data)
|
33 |
+
return df
|
34 |
+
|
35 |
+
def create_mock_data_donut(min_value=10000, max_value=500000):
|
36 |
+
"""
|
37 |
+
가상의 인구 이동 데이터를 생성합니다.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
tuple: (인바운드 이동 비율, 아웃바운드 이동 비율)
|
41 |
+
"""
|
42 |
+
# 랜덤 값 생성 (10000~500000 사이)
|
43 |
+
inbound = random.randint(min_value, max_value)
|
44 |
+
outbound = random.randint(min_value, max_value)
|
45 |
+
|
46 |
+
# 전체 값 대비 비율 계산 (0-100 사이의 값으로 변환)
|
47 |
+
total = inbound + outbound
|
48 |
+
inbound_percent = round((inbound / total) * 100)
|
49 |
+
outbound_percent = round((outbound / total) * 100)
|
50 |
+
|
51 |
+
return inbound_percent, outbound_percent
|
52 |
+
|
53 |
+
|
54 |
+
def create_mock_data_inout():
|
55 |
+
"""
|
56 |
+
방문객 데이터 랜덤 생성
|
57 |
+
- 이번달 방문객: 150,000 ~ 500,000
|
58 |
+
- 오늘 방문객: 5,000 ~ 100,000
|
59 |
+
- delta는 전월/전일 대비 증감량 (-30% ~ +30%)
|
60 |
+
"""
|
61 |
+
# 이번달 방문객 (더 큰 범위)
|
62 |
+
monthly_visitors = random.randint(150000, 500000)
|
63 |
+
monthly_delta = int(monthly_visitors * random.uniform(-0.3, 0.3)) # 30% 범위 내 증감
|
64 |
+
|
65 |
+
# 오늘 방문객 (더 작은 범위)
|
66 |
+
daily_visitors = random.randint(5000, 100000)
|
67 |
+
daily_delta = int(daily_visitors * random.uniform(-0.3, 0.3)) # 30% 범위 내 증감
|
68 |
+
|
69 |
+
return {
|
70 |
+
'top': {
|
71 |
+
'state': '이번달 방문객',
|
72 |
+
'visitor': monthly_visitors,
|
73 |
+
'delta': monthly_delta
|
74 |
+
},
|
75 |
+
'bottom': {
|
76 |
+
'state': '오늘 방문객',
|
77 |
+
'visitor': daily_visitors,
|
78 |
+
'delta': daily_delta
|
79 |
+
}
|
80 |
+
}
|
custom/visual.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import altair as alt
|
2 |
+
import pandas as pd
|
3 |
+
from typing import Tuple, Literal, Union
|
4 |
+
# Heatmap
|
5 |
+
def make_heatmap(input_df, input_y, input_x, input_color, input_color_theme):
|
6 |
+
heatmap = alt.Chart(input_df).mark_rect().encode(
|
7 |
+
y=alt.Y(f'{input_y}:O', axis=alt.Axis(title="Month", titleFontSize=18, titlePadding=15, titleFontWeight=900, labelAngle=0)),
|
8 |
+
x=alt.X(f'{input_x}:O', axis=alt.Axis(title="", titleFontSize=18, titlePadding=15, titleFontWeight=900, labelAngle=0)),
|
9 |
+
color=alt.Color(f'max({input_color}):Q',
|
10 |
+
legend=None,
|
11 |
+
scale=alt.Scale(scheme=input_color_theme)),
|
12 |
+
stroke=alt.value('black'),
|
13 |
+
strokeWidth=alt.value(0.25),
|
14 |
+
).properties(width=900
|
15 |
+
).configure_axis(
|
16 |
+
labelFontSize=12,
|
17 |
+
titleFontSize=12
|
18 |
+
)
|
19 |
+
# height=300
|
20 |
+
return heatmap
|
21 |
+
|
22 |
+
|
23 |
+
# Donut chart
|
24 |
+
def make_donut(
|
25 |
+
input_response: float,
|
26 |
+
input_text: str,
|
27 |
+
input_color: Literal['blue', 'green', 'orange', 'red']
|
28 |
+
) -> alt.LayerChart:
|
29 |
+
"""
|
30 |
+
Altair를 사용하여 지정된 퍼센트, 레이블, 색상 스키마로 도넛 차트를 생성합니다.
|
31 |
+
|
32 |
+
함수 구조:
|
33 |
+
1. 입력 색상에 따른 색상 스키마 정의
|
34 |
+
2. 두 개의 DataFrame 생성:
|
35 |
+
- 퍼센트 표시를 위한 메인 데이터
|
36 |
+
- 전체 원을 위한 배경 데이터
|
37 |
+
3. 세 개의 레이어 생성:
|
38 |
+
- 배경 원 (plot_bg)
|
39 |
+
- 퍼센트 호 (plot)
|
40 |
+
- 중앙 텍스트 표시
|
41 |
+
|
42 |
+
매개변수:
|
43 |
+
----------
|
44 |
+
input_response : float
|
45 |
+
표시할 퍼센트 값 (0-100 사이)
|
46 |
+
input_text : str
|
47 |
+
차트에 표시할 레이블 텍스트
|
48 |
+
input_color : str
|
49 |
+
사용할 색상 스키마 ('blue', 'green', 'orange', 'red' 중 하나)
|
50 |
+
|
51 |
+
반환값:
|
52 |
+
-------
|
53 |
+
alt.LayerChart
|
54 |
+
배경, 퍼센트 호, 중앙 텍스트가 결합된 Altair 레이어 차트
|
55 |
+
|
56 |
+
사용 예시:
|
57 |
+
---------
|
58 |
+
>>> chart = make_donut(75, "완료", "blue")
|
59 |
+
>>> chart.save('donut.html')
|
60 |
+
"""
|
61 |
+
if input_color == 'blue':
|
62 |
+
chart_color = ['#29b5e8', '#155F7A']
|
63 |
+
if input_color == 'green':
|
64 |
+
chart_color = ['#27AE60', '#12783D']
|
65 |
+
if input_color == 'orange':
|
66 |
+
chart_color = ['#F39C12', '#875A12']
|
67 |
+
if input_color == 'red':
|
68 |
+
chart_color = ['#E74C3C', '#781F16']
|
69 |
+
|
70 |
+
source = pd.DataFrame({
|
71 |
+
"Topic": ['', input_text],
|
72 |
+
"% value": [100-input_response, input_response]
|
73 |
+
})
|
74 |
+
source_bg = pd.DataFrame({
|
75 |
+
"Topic": ['', input_text],
|
76 |
+
"% value": [100, 0]
|
77 |
+
})
|
78 |
+
|
79 |
+
plot = alt.Chart(source).mark_arc(innerRadius=45, cornerRadius=25).encode(
|
80 |
+
theta="% value",
|
81 |
+
color= alt.Color("Topic:N",
|
82 |
+
scale=alt.Scale(
|
83 |
+
#domain=['A', 'B'],
|
84 |
+
domain=[input_text, ''],
|
85 |
+
# range=['#29b5e8', '#155F7A']), # 31333F
|
86 |
+
range=chart_color),
|
87 |
+
legend=None),
|
88 |
+
).properties(width=130, height=130)
|
89 |
+
|
90 |
+
text = plot.mark_text(align='center', color="#29b5e8", font="Lato", fontSize=32, fontWeight=700, fontStyle="italic").encode(text=alt.value(f'{input_response} %'))
|
91 |
+
plot_bg = alt.Chart(source_bg).mark_arc(innerRadius=45, cornerRadius=20).encode(
|
92 |
+
theta="% value",
|
93 |
+
color= alt.Color("Topic:N",
|
94 |
+
scale=alt.Scale(
|
95 |
+
# domain=['A', 'B'],
|
96 |
+
domain=[input_text, ''],
|
97 |
+
range=chart_color), # 31333F
|
98 |
+
legend=None),
|
99 |
+
).properties(width=130, height=130)
|
100 |
+
return plot_bg + plot + text
|
losses/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .dm_loss import DMLoss
|
2 |
+
from .dace_loss import DACELoss
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"DMLoss",
|
6 |
+
"DACELoss",
|
7 |
+
]
|
losses/bregman_pytorch.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code modified from https://github.com/cvlab-stonybrook/DM-Count/blob/master/losses/bregman_pytorch.py
|
2 |
+
import torch
|
3 |
+
from torch import Tensor
|
4 |
+
from torch.cuda.amp import autocast
|
5 |
+
from typing import Union, Tuple, Dict
|
6 |
+
|
7 |
+
M_EPS = 1e-16
|
8 |
+
|
9 |
+
|
10 |
+
@autocast(enabled=True, dtype=torch.float32) # avoid numerical instability
|
11 |
+
def sinkhorn(
|
12 |
+
a: Tensor,
|
13 |
+
b: Tensor,
|
14 |
+
C: Tensor,
|
15 |
+
reg: float = 1e-1,
|
16 |
+
maxIter: int = 1000,
|
17 |
+
stopThr: float = 1e-9,
|
18 |
+
verbose: bool = False,
|
19 |
+
log: bool = True,
|
20 |
+
eval_freq: int = 10,
|
21 |
+
print_freq: int = 200,
|
22 |
+
) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]:
|
23 |
+
"""
|
24 |
+
Solve the entropic regularization optimal transport
|
25 |
+
The input should be PyTorch tensors
|
26 |
+
The function solves the following optimization problem:
|
27 |
+
|
28 |
+
.. math::
|
29 |
+
\gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
|
30 |
+
s.t. \gamma 1 = a
|
31 |
+
\gamma^T 1= b
|
32 |
+
\gamma\geq 0
|
33 |
+
where :
|
34 |
+
- C is the (ns,nt) metric cost matrix
|
35 |
+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
|
36 |
+
- a and b are target and source measures (sum to 1)
|
37 |
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
a : torch.tensor (na,)
|
42 |
+
samples measure in the target domain
|
43 |
+
b : torch.tensor (nb,)
|
44 |
+
samples in the source domain
|
45 |
+
C : torch.tensor (na,nb)
|
46 |
+
loss matrix
|
47 |
+
reg : float
|
48 |
+
Regularization term > 0
|
49 |
+
maxIter : int, optional
|
50 |
+
Max number of iterations
|
51 |
+
stopThr : float, optional
|
52 |
+
Stop threshol on error ( > 0 )
|
53 |
+
verbose : bool, optional
|
54 |
+
Print information along iterations
|
55 |
+
log : bool, optional
|
56 |
+
record log if True
|
57 |
+
|
58 |
+
Returns
|
59 |
+
-------
|
60 |
+
gamma : (na x nb) torch.tensor
|
61 |
+
Optimal transportation matrix for the given parameters
|
62 |
+
log : dict
|
63 |
+
log dictionary return only if log==True in parameters
|
64 |
+
|
65 |
+
References
|
66 |
+
----------
|
67 |
+
[1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
|
68 |
+
See Also
|
69 |
+
--------
|
70 |
+
"""
|
71 |
+
|
72 |
+
device = a.device
|
73 |
+
na, nb = C.shape
|
74 |
+
|
75 |
+
# a = a.view(-1, 1)
|
76 |
+
# b = b.view(-1, 1)
|
77 |
+
|
78 |
+
assert na >= 1 and nb >= 1, f"C needs to be 2d. Found C.shape = {C.shape}"
|
79 |
+
assert na == a.shape[0] and nb == b.shape[0], f"Shape of a ({a.shape}) or b ({b.shape}) does not match that of C ({C.shape})"
|
80 |
+
assert reg > 0, f"reg should be greater than 0. Found reg = {reg}"
|
81 |
+
assert a.min() >= 0. and b.min() >= 0., f"Elements in a and b should be nonnegative. Found a.min() = {a.min()}, b.min() = {b.min()}"
|
82 |
+
|
83 |
+
if log:
|
84 |
+
log = {"err": []}
|
85 |
+
|
86 |
+
u = torch.ones((na), dtype=a.dtype).to(device) / na
|
87 |
+
v = torch.ones((nb), dtype=b.dtype).to(device) / nb
|
88 |
+
|
89 |
+
K = torch.empty(C.shape, dtype=C.dtype).to(device)
|
90 |
+
torch.div(C, -reg, out=K)
|
91 |
+
torch.exp(K, out=K)
|
92 |
+
|
93 |
+
b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
|
94 |
+
|
95 |
+
it = 1
|
96 |
+
err = 1
|
97 |
+
|
98 |
+
# allocate memory beforehand
|
99 |
+
KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
|
100 |
+
Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
|
101 |
+
|
102 |
+
while (err > stopThr and it <= maxIter):
|
103 |
+
upre, vpre = u, v
|
104 |
+
# torch.matmul(u, K, out=KTu)
|
105 |
+
KTu = torch.matmul(u.view(1, -1), K).view(-1)
|
106 |
+
v = torch.div(b, KTu + M_EPS)
|
107 |
+
# torch.matmul(K, v, out=Kv)
|
108 |
+
Kv = torch.matmul(K, v.view(-1, 1)).view(-1)
|
109 |
+
u = torch.div(a, Kv + M_EPS)
|
110 |
+
|
111 |
+
if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
|
112 |
+
torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
|
113 |
+
print("Warning: numerical errors at iteration", it)
|
114 |
+
u, v = upre, vpre
|
115 |
+
break
|
116 |
+
|
117 |
+
if log and it % eval_freq == 0:
|
118 |
+
# we can speed up the process by checking for the error only all
|
119 |
+
# the eval_freq iterations
|
120 |
+
# below is equivalent to:
|
121 |
+
# b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
|
122 |
+
# but with more memory efficient
|
123 |
+
b_hat = (torch.matmul(u.view(1, -1), K) * v.view(1, -1)).view(-1)
|
124 |
+
err = (b - b_hat).pow(2).sum().item()
|
125 |
+
# err = (b - b_hat).abs().sum().item()
|
126 |
+
log["err"].append(err)
|
127 |
+
|
128 |
+
if verbose and it % print_freq == 0:
|
129 |
+
print("iteration {:5d}, constraint error {:5e}".format(it, err))
|
130 |
+
|
131 |
+
it += 1
|
132 |
+
|
133 |
+
if log:
|
134 |
+
log["u"] = u
|
135 |
+
log["v"] = v
|
136 |
+
log["alpha"] = reg * torch.log(u + M_EPS)
|
137 |
+
log["beta"] = reg * torch.log(v + M_EPS)
|
138 |
+
|
139 |
+
# transport plan
|
140 |
+
P = u.reshape(-1, 1) * K * v.reshape(1, -1)
|
141 |
+
if log:
|
142 |
+
return P, log
|
143 |
+
else:
|
144 |
+
return P
|
losses/dace_loss.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from typing import Any, List, Tuple, Dict
|
4 |
+
|
5 |
+
from .dm_loss import DMLoss
|
6 |
+
from .utils import _reshape_density
|
7 |
+
|
8 |
+
|
9 |
+
class DACELoss(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
bins: List[Tuple[float, float]],
|
13 |
+
reduction: int,
|
14 |
+
weight_count_loss: float = 1.0,
|
15 |
+
count_loss: str = "mae",
|
16 |
+
**kwargs: Any
|
17 |
+
) -> None:
|
18 |
+
super().__init__()
|
19 |
+
assert len(bins) > 0, f"Expected at least one bin, got {bins}"
|
20 |
+
assert all([len(b) == 2 for b in bins]), f"Expected all bins to be of length 2, got {bins}"
|
21 |
+
assert all([b[0] <= b[1] for b in bins]), f"Expected all bins to be in increasing order, got {bins}"
|
22 |
+
self.bins = bins
|
23 |
+
self.reduction = reduction
|
24 |
+
self.cross_entropy_fn = nn.CrossEntropyLoss(reduction="none")
|
25 |
+
|
26 |
+
count_loss = count_loss.lower()
|
27 |
+
assert count_loss in ["mae", "mse", "dmcount"], f"Expected count_loss to be one of ['mae', 'mse', 'dmcount'], got {count_loss}"
|
28 |
+
self.count_loss = count_loss
|
29 |
+
if self.count_loss == "mae":
|
30 |
+
self.use_dm_loss = False
|
31 |
+
self.count_loss_fn = nn.L1Loss(reduction="none")
|
32 |
+
elif self.count_loss == "mse":
|
33 |
+
self.use_dm_loss = False
|
34 |
+
self.count_loss_fn = nn.MSELoss(reduction="none")
|
35 |
+
else:
|
36 |
+
self.use_dm_loss = True
|
37 |
+
assert "input_size" in kwargs, f"Expected input_size to be in kwargs when count_loss='dmcount', got {kwargs}"
|
38 |
+
self.count_loss_fn = DMLoss(reduction=reduction, **kwargs)
|
39 |
+
|
40 |
+
self.weight_count_loss = weight_count_loss
|
41 |
+
|
42 |
+
def _bin_count(self, density_map: Tensor) -> Tensor:
|
43 |
+
class_map = torch.zeros_like(density_map, dtype=torch.long)
|
44 |
+
for idx, (low, high) in enumerate(self.bins):
|
45 |
+
mask = (density_map >= low) & (density_map <= high)
|
46 |
+
class_map[mask] = idx
|
47 |
+
return class_map.squeeze(1) # remove channel dimension
|
48 |
+
|
49 |
+
def forward(self, pred_class: Tensor, pred_density: Tensor, target_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
|
50 |
+
target_density = _reshape_density(target_density, reduction=self.reduction) if target_density.shape[-2:] != pred_density.shape[-2:] else target_density
|
51 |
+
assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}"
|
52 |
+
|
53 |
+
target_class = self._bin_count(target_density)
|
54 |
+
|
55 |
+
cross_entropy_loss = self.cross_entropy_fn(pred_class, target_class).sum(dim=(-1, -2)).mean()
|
56 |
+
|
57 |
+
if self.use_dm_loss:
|
58 |
+
count_loss, loss_info = self.count_loss_fn(pred_density, target_density, target_points)
|
59 |
+
loss_info["ce_loss"] = cross_entropy_loss.detach()
|
60 |
+
else:
|
61 |
+
count_loss = self.count_loss_fn(pred_density, target_density).sum(dim=(-1, -2, -3)).mean()
|
62 |
+
loss_info = {
|
63 |
+
"ce_loss": cross_entropy_loss.detach(),
|
64 |
+
f"{self.count_loss}_loss": count_loss.detach(),
|
65 |
+
}
|
66 |
+
|
67 |
+
loss = cross_entropy_loss + self.weight_count_loss * count_loss
|
68 |
+
loss_info["loss"] = loss.detach()
|
69 |
+
|
70 |
+
return loss, loss_info
|
losses/dm_loss.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from torch.cuda.amp import autocast
|
4 |
+
from typing import List, Any, Tuple, Dict
|
5 |
+
|
6 |
+
from .bregman_pytorch import sinkhorn
|
7 |
+
from .utils import _reshape_density
|
8 |
+
|
9 |
+
EPS = 1e-8
|
10 |
+
|
11 |
+
|
12 |
+
class OTLoss(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
input_size: int,
|
16 |
+
reduction: int,
|
17 |
+
norm_cood: bool,
|
18 |
+
num_of_iter_in_ot: int = 100,
|
19 |
+
reg: float = 10.0
|
20 |
+
) -> None:
|
21 |
+
super().__init__()
|
22 |
+
assert input_size % reduction == 0
|
23 |
+
|
24 |
+
self.input_size = input_size
|
25 |
+
self.reduction = reduction
|
26 |
+
self.norm_cood = norm_cood
|
27 |
+
self.num_of_iter_in_ot = num_of_iter_in_ot
|
28 |
+
self.reg = reg
|
29 |
+
|
30 |
+
# coordinate is same to image space, set to constant since crop size is same
|
31 |
+
self.cood = torch.arange(0, input_size, step=reduction, dtype=torch.float32) + reduction / 2
|
32 |
+
self.density_size = self.cood.size(0)
|
33 |
+
self.cood.unsqueeze_(0) # [1, #cood]
|
34 |
+
self.cood = self.cood / input_size * 2 - 1 if self.norm_cood else self.cood
|
35 |
+
self.output_size = self.cood.size(1)
|
36 |
+
|
37 |
+
@autocast(enabled=True, dtype=torch.float32) # avoid numerical instability
|
38 |
+
def forward(self, pred_density: Tensor, normed_pred_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, float, Tensor]:
|
39 |
+
batch_size = normed_pred_density.size(0)
|
40 |
+
assert len(target_points) == batch_size, f"Expected target_points to have length {batch_size}, but got {len(target_points)}"
|
41 |
+
assert self.output_size == normed_pred_density.size(2)
|
42 |
+
device = pred_density.device
|
43 |
+
|
44 |
+
loss = torch.zeros([1]).to(device)
|
45 |
+
ot_obj_values = torch.zeros([1]).to(device)
|
46 |
+
wd = 0 # Wasserstein distance
|
47 |
+
cood = self.cood.to(device)
|
48 |
+
for idx, points in enumerate(target_points):
|
49 |
+
if len(points) > 0:
|
50 |
+
# compute l2 square distance, it should be source target distance. [#gt, #cood * #cood]
|
51 |
+
points = points / self.input_size * 2 - 1 if self.norm_cood else points
|
52 |
+
x = points[:, 0].unsqueeze_(1) # [#gt, 1]
|
53 |
+
y = points[:, 1].unsqueeze_(1)
|
54 |
+
x_dist = -2 * torch.matmul(x, cood) + x * x + cood * cood # [#gt, #cood]
|
55 |
+
y_dist = -2 * torch.matmul(y, cood) + y * y + cood * cood
|
56 |
+
y_dist.unsqueeze_(2)
|
57 |
+
x_dist.unsqueeze_(1)
|
58 |
+
dist = y_dist + x_dist
|
59 |
+
dist = dist.view((dist.size(0), -1)) # size of [#gt, #cood * #cood]
|
60 |
+
|
61 |
+
source_prob = normed_pred_density[idx][0].view([-1]).detach()
|
62 |
+
target_prob = (torch.ones([len(points)]) / len(points)).to(device)
|
63 |
+
# use sinkhorn to solve OT, compute optimal beta.
|
64 |
+
P, log = sinkhorn(target_prob, source_prob, dist, self.reg, maxIter=self.num_of_iter_in_ot, log=True)
|
65 |
+
beta = log["beta"] # size is the same as source_prob: [#cood * #cood]
|
66 |
+
ot_obj_values += torch.sum(normed_pred_density[idx] * beta.view([1, self.output_size, self.output_size]))
|
67 |
+
# compute the gradient of OT loss to predicted density (pred_density).
|
68 |
+
# im_grad = beta / source_count - < beta, source_density> / (source_count)^2
|
69 |
+
source_density = pred_density[idx][0].view([-1]).detach()
|
70 |
+
source_count = source_density.sum()
|
71 |
+
gradient_1 = (source_count) / (source_count * source_count+ EPS) * beta # size of [#cood * #cood]
|
72 |
+
gradient_2 = (source_density * beta).sum() / (source_count * source_count + EPS) # size of 1
|
73 |
+
gradient = gradient_1 - gradient_2
|
74 |
+
gradient = gradient.detach().view([1, self.output_size, self.output_size])
|
75 |
+
# Define loss = <im_grad, predicted density>. The gradient of loss w.r.t predicted density is im_grad.
|
76 |
+
loss += torch.sum(pred_density[idx] * gradient)
|
77 |
+
wd += torch.sum(dist * P).item()
|
78 |
+
|
79 |
+
return loss, wd, ot_obj_values
|
80 |
+
|
81 |
+
|
82 |
+
class DMLoss(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
input_size: int,
|
86 |
+
reduction: int,
|
87 |
+
norm_cood: bool = False,
|
88 |
+
weight_ot: float = 0.1,
|
89 |
+
weight_tv: float = 0.01,
|
90 |
+
**kwargs: Any
|
91 |
+
) -> None:
|
92 |
+
super().__init__()
|
93 |
+
self.ot_loss = OTLoss(input_size, reduction, norm_cood, **kwargs)
|
94 |
+
self.tv_loss = nn.L1Loss(reduction="none")
|
95 |
+
self.count_loss = nn.L1Loss(reduction="mean")
|
96 |
+
self.weight_ot = weight_ot
|
97 |
+
self.weight_tv = weight_tv
|
98 |
+
|
99 |
+
@autocast(enabled=True, dtype=torch.float32) # avoid numerical instability
|
100 |
+
def forward(self, pred_density: Tensor, target_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
|
101 |
+
target_density = _reshape_density(target_density, reduction=self.ot_loss.reduction) if target_density.shape[-2:] != pred_density.shape[-2:] else target_density
|
102 |
+
assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}"
|
103 |
+
|
104 |
+
pred_count = pred_density.view(pred_density.shape[0], -1).sum(dim=1)
|
105 |
+
normed_pred_density = pred_density / (pred_count.view(-1, 1, 1, 1) + EPS)
|
106 |
+
target_count = torch.tensor([len(p) for p in target_points], dtype=torch.float32).to(target_density.device)
|
107 |
+
normed_target_density = target_density / (target_count.view(-1, 1, 1, 1) + EPS)
|
108 |
+
|
109 |
+
ot_loss, _, _ = self.ot_loss(pred_density, normed_pred_density, target_points)
|
110 |
+
|
111 |
+
tv_loss = (self.tv_loss(normed_pred_density, normed_target_density).sum(dim=(1, 2, 3)) * target_count).mean()
|
112 |
+
|
113 |
+
count_loss = self.count_loss(pred_count, target_count)
|
114 |
+
|
115 |
+
loss = ot_loss * self.weight_ot + tv_loss * self.weight_tv + count_loss
|
116 |
+
|
117 |
+
loss_info = {
|
118 |
+
"loss": loss.detach(),
|
119 |
+
"ot_loss": ot_loss.detach(),
|
120 |
+
"tv_loss": tv_loss.detach(),
|
121 |
+
"count_loss": count_loss.detach(),
|
122 |
+
}
|
123 |
+
|
124 |
+
return loss, loss_info
|
losses/utils.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
|
3 |
+
|
4 |
+
def _reshape_density(density: Tensor, reduction: int) -> Tensor:
|
5 |
+
assert len(density.shape) == 4, f"Expected 4D (B, 1, H, W) tensor, got {density.shape}"
|
6 |
+
assert density.shape[1] == 1, f"Expected 1 channel, got {density.shape[1]}"
|
7 |
+
assert density.shape[2] % reduction == 0, f"Expected height to be divisible by {reduction}, got {density.shape[2]}"
|
8 |
+
assert density.shape[3] % reduction == 0, f"Expected width to be divisible by {reduction}, got {density.shape[3]}"
|
9 |
+
return density.reshape(density.shape[0], 1, density.shape[2] // reduction, reduction, density.shape[3] // reduction, reduction).sum(dim=(-1, -3))
|
main.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from custom.clip_ebc_onnx import ClipEBCOnnx
|
4 |
+
|
5 |
+
def parse_args():
|
6 |
+
parser = argparse.ArgumentParser(description='CLIP-EBC Crowd Counting (ONNX)')
|
7 |
+
parser.add_argument('--image', required=True, help='Path to input image')
|
8 |
+
parser.add_argument('--model', default='assets/CLIP_EBC_nwpu_rmse_onnx.onnx', help='Path to ONNX model')
|
9 |
+
parser.add_argument('--visualize', choices=['density', 'dots', 'all', 'none'],
|
10 |
+
default='none', help='Visualization type')
|
11 |
+
parser.add_argument('--save', action='store_true',
|
12 |
+
help='Save visualization results')
|
13 |
+
parser.add_argument('--output-dir', default='results',
|
14 |
+
help='Directory to save results')
|
15 |
+
|
16 |
+
# 시각화 관련 매개변수
|
17 |
+
parser.add_argument('--alpha', type=float, default=0.5,
|
18 |
+
help='Alpha value for density map')
|
19 |
+
parser.add_argument('--dot-size', type=int, default=20,
|
20 |
+
help='Dot size for dot visualization')
|
21 |
+
parser.add_argument('--sigma', type=float, default=1,
|
22 |
+
help='Sigma value for Gaussian filter')
|
23 |
+
parser.add_argument('--percentile', type=float, default=97,
|
24 |
+
help='Percentile threshold for dot visualization')
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
return parser.parse_args()
|
29 |
+
|
30 |
+
def main():
|
31 |
+
args = parse_args()
|
32 |
+
|
33 |
+
# 모델 초기화 - ONNX 버전
|
34 |
+
model = ClipEBCOnnx(
|
35 |
+
onnx_model_path=args.model
|
36 |
+
)
|
37 |
+
|
38 |
+
# 출력 디렉토리 생성
|
39 |
+
if args.save:
|
40 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
41 |
+
|
42 |
+
# 예측 수행
|
43 |
+
count = model.predict(args.image)
|
44 |
+
print(f"예측된 군중 수: {count:.2f}")
|
45 |
+
|
46 |
+
# 시각화
|
47 |
+
if args.visualize in ['density', 'all']:
|
48 |
+
save_path = os.path.join(args.output_dir, 'density_map.png') if args.save else None
|
49 |
+
fig, density_map = model.visualize_density_map(
|
50 |
+
alpha=args.alpha,
|
51 |
+
save=args.save,
|
52 |
+
save_path=save_path
|
53 |
+
)
|
54 |
+
|
55 |
+
if args.visualize in ['dots', 'all']:
|
56 |
+
save_path = os.path.join(args.output_dir, 'dot_map.png') if args.save else None
|
57 |
+
canvas, dot_map = model.visualize_dots(
|
58 |
+
dot_size=args.dot_size,
|
59 |
+
sigma=args.sigma,
|
60 |
+
percentile=args.percentile,
|
61 |
+
save=args.save,
|
62 |
+
save_path=save_path
|
63 |
+
)
|
64 |
+
|
65 |
+
# matplotlib figure 닫기 (메모리 누수 방지)
|
66 |
+
if args.visualize in ['density', 'all']:
|
67 |
+
import matplotlib.pyplot as plt
|
68 |
+
plt.close(fig)
|
69 |
+
|
70 |
+
if args.visualize in ['dots', 'all']:
|
71 |
+
import matplotlib.pyplot as plt
|
72 |
+
plt.close(canvas.figure)
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
main()
|
models/__init__.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Optional, Any, Union
|
2 |
+
|
3 |
+
from .model import _classifier, _regressor, Classifier, Regressor
|
4 |
+
from .clip import _clip_ebc, CLIP_EBC
|
5 |
+
import assets
|
6 |
+
|
7 |
+
clip_names = ["resnet50", "resnet50x4", "resnet50x16", "resnet50x64", "resnet101", "vit_b_16", "vit_b_32", "vit_l_14"]
|
8 |
+
|
9 |
+
|
10 |
+
def get_model(
|
11 |
+
backbone: str,
|
12 |
+
input_size: int,
|
13 |
+
reduction: int,
|
14 |
+
bins: Optional[List[Tuple[float, float]]] = None,
|
15 |
+
anchor_points: Optional[List[float]] = None,
|
16 |
+
**kwargs: Any,
|
17 |
+
) -> Union[Regressor, Classifier, CLIP_EBC]:
|
18 |
+
backbone = backbone.lower()
|
19 |
+
if "clip" in backbone:
|
20 |
+
backbone = backbone[5:]
|
21 |
+
assert backbone in clip_names, f"Expected backbone to be in {clip_names}, got {backbone}"
|
22 |
+
return _clip_ebc(
|
23 |
+
backbone=backbone,
|
24 |
+
input_size=input_size,
|
25 |
+
reduction=reduction,
|
26 |
+
bins=bins,
|
27 |
+
anchor_points=anchor_points,
|
28 |
+
**kwargs
|
29 |
+
)
|
30 |
+
elif bins is None and anchor_points is None:
|
31 |
+
return _regressor(
|
32 |
+
backbone=backbone,
|
33 |
+
input_size=input_size,
|
34 |
+
reduction=reduction,
|
35 |
+
)
|
36 |
+
else:
|
37 |
+
assert bins is not None and anchor_points is not None, f"Expected bins and anchor_points to be both None or not None, got {bins} and {anchor_points}"
|
38 |
+
return _classifier(
|
39 |
+
backbone=backbone,
|
40 |
+
input_size=input_size,
|
41 |
+
reduction=reduction,
|
42 |
+
bins=bins,
|
43 |
+
anchor_points=anchor_points,
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
__all__ = [
|
48 |
+
"get_model",
|
49 |
+
]
|
models/clip/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model import CLIP_EBC, _clip_ebc
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"CLIP_EBC",
|
6 |
+
"_clip_ebc",
|
7 |
+
]
|
models/clip/_clip/__init__.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from typing import Tuple, Optional, Any, Union
|
4 |
+
import json
|
5 |
+
|
6 |
+
from .utils import tokenize, transform
|
7 |
+
from .prepare import prepare
|
8 |
+
from .text_encoder import CLIPTextEncoder
|
9 |
+
from .image_encoder import ModifiedResNet, VisionTransformer
|
10 |
+
from .model import CLIP
|
11 |
+
|
12 |
+
|
13 |
+
curr_dir = os.path.dirname(os.path.abspath(__file__))
|
14 |
+
|
15 |
+
clip_model_names = [
|
16 |
+
"clip_resnet50",
|
17 |
+
"clip_resnet101",
|
18 |
+
"clip_resnet50x4",
|
19 |
+
"clip_resnet50x16",
|
20 |
+
"clip_resnet50x64",
|
21 |
+
"clip_vit_b_32",
|
22 |
+
"clip_vit_b_16",
|
23 |
+
"clip_vit_l_14",
|
24 |
+
"clip_vit_l_14_336px",
|
25 |
+
]
|
26 |
+
|
27 |
+
clip_image_encoder_names = [f"clip_image_encoder_{name[5:]}" for name in clip_model_names]
|
28 |
+
clip_text_encoder_names = [f"clip_text_encoder_{name[5:]}" for name in clip_model_names]
|
29 |
+
|
30 |
+
|
31 |
+
for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names:
|
32 |
+
model_weights_path = os.path.join(curr_dir, "weights", f"{name}.pth")
|
33 |
+
model_config_path = os.path.join(curr_dir, "configs", f"{name}.json")
|
34 |
+
if not os.path.exists(os.path.join(curr_dir, "weights", f"{name}.pth")) or not os.path.exists(os.path.join(curr_dir, "configs", f"{name}.json")):
|
35 |
+
prepare()
|
36 |
+
break
|
37 |
+
|
38 |
+
|
39 |
+
for name in clip_model_names + clip_image_encoder_names + clip_text_encoder_names:
|
40 |
+
assert os.path.exists(os.path.join(curr_dir, "weights", f"{name}.pth")), f"Missing {name}.pth in weights folder. Please run models/clip/prepare.py to download the weights."
|
41 |
+
assert os.path.exists(os.path.join(curr_dir, "configs", f"{name}.json")), f"Missing {name}.json in configs folder. Please run models/clip/prepare.py to download the configs."
|
42 |
+
|
43 |
+
|
44 |
+
def _clip(name: str, input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
45 |
+
with open(os.path.join(curr_dir, "configs", f"clip_{name}.json"), "r") as f:
|
46 |
+
config = json.load(f)
|
47 |
+
|
48 |
+
model = CLIP(
|
49 |
+
embed_dim=config["embed_dim"],
|
50 |
+
# vision
|
51 |
+
image_resolution=config["image_resolution"],
|
52 |
+
vision_layers=config["vision_layers"],
|
53 |
+
vision_width=config["vision_width"],
|
54 |
+
vision_patch_size=config["vision_patch_size"],
|
55 |
+
# text
|
56 |
+
context_length=config["context_length"],
|
57 |
+
vocab_size=config["vocab_size"],
|
58 |
+
transformer_width=config["transformer_width"],
|
59 |
+
transformer_heads=config["transformer_heads"],
|
60 |
+
transformer_layers=config["transformer_layers"]
|
61 |
+
)
|
62 |
+
state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_{name}.pth"), map_location="cpu")
|
63 |
+
model.load_state_dict(state_dict, strict=True)
|
64 |
+
|
65 |
+
if input_size is not None:
|
66 |
+
input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
|
67 |
+
if name.startswith("vit"):
|
68 |
+
model.visual.adjust_pos_embed(*input_size)
|
69 |
+
|
70 |
+
return model
|
71 |
+
|
72 |
+
|
73 |
+
def _resnet(
|
74 |
+
name: str,
|
75 |
+
reduction: int = 32,
|
76 |
+
features_only: bool = False,
|
77 |
+
out_indices: Optional[Tuple[int, ...]] = None,
|
78 |
+
**kwargs: Any
|
79 |
+
) -> ModifiedResNet:
|
80 |
+
with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f:
|
81 |
+
config = json.load(f)
|
82 |
+
model = ModifiedResNet(
|
83 |
+
layers=config["vision_layers"],
|
84 |
+
output_dim=config["embed_dim"],
|
85 |
+
input_resolution=config["image_resolution"],
|
86 |
+
width=config["vision_width"],
|
87 |
+
heads=config["vision_heads"],
|
88 |
+
features_only=features_only,
|
89 |
+
out_indices=out_indices,
|
90 |
+
reduction=reduction
|
91 |
+
)
|
92 |
+
state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu")
|
93 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
94 |
+
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
|
95 |
+
print(f"Missing keys: {missing_keys}")
|
96 |
+
print(f"Unexpected keys: {unexpected_keys}")
|
97 |
+
else:
|
98 |
+
print(f"All keys matched successfully.")
|
99 |
+
|
100 |
+
return model
|
101 |
+
|
102 |
+
|
103 |
+
def _vit(name: str, features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
|
104 |
+
with open(os.path.join(curr_dir, "configs", f"clip_image_encoder_{name}.json"), "r") as f:
|
105 |
+
config = json.load(f)
|
106 |
+
model = VisionTransformer(
|
107 |
+
input_resolution=config["image_resolution"],
|
108 |
+
patch_size=config["vision_patch_size"],
|
109 |
+
output_dim=config["embed_dim"],
|
110 |
+
width=config["vision_width"],
|
111 |
+
layers=config["vision_layers"],
|
112 |
+
heads=config["vision_heads"],
|
113 |
+
features_only=features_only
|
114 |
+
)
|
115 |
+
state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_image_encoder_{name}.pth"), map_location="cpu")
|
116 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
117 |
+
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
|
118 |
+
print(f"Missing keys: {missing_keys}")
|
119 |
+
print(f"Unexpected keys: {unexpected_keys}")
|
120 |
+
else:
|
121 |
+
print(f"All keys matched successfully.")
|
122 |
+
|
123 |
+
if input_size is not None:
|
124 |
+
input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
|
125 |
+
model.adjust_pos_embed(*input_size)
|
126 |
+
return model
|
127 |
+
|
128 |
+
|
129 |
+
def _text_encoder(name: str) -> CLIPTextEncoder:
|
130 |
+
with open(os.path.join(curr_dir, "configs", f"clip_text_encoder_{name}.json"), "r") as f:
|
131 |
+
config = json.load(f)
|
132 |
+
model = CLIPTextEncoder(
|
133 |
+
embed_dim=config["embed_dim"],
|
134 |
+
context_length=config["context_length"],
|
135 |
+
vocab_size=config["vocab_size"],
|
136 |
+
transformer_width=config["transformer_width"],
|
137 |
+
transformer_heads=config["transformer_heads"],
|
138 |
+
transformer_layers=config["transformer_layers"]
|
139 |
+
)
|
140 |
+
state_dict = torch.load(os.path.join(curr_dir, "weights", f"clip_text_encoder_{name}.pth"), map_location="cpu")
|
141 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
142 |
+
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
|
143 |
+
print(f"Missing keys: {missing_keys}")
|
144 |
+
print(f"Unexpected keys: {unexpected_keys}")
|
145 |
+
else:
|
146 |
+
print(f"All keys matched successfully.")
|
147 |
+
|
148 |
+
return model
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
# CLIP models
|
153 |
+
def resnet50_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
154 |
+
return _clip("resnet50", input_size)
|
155 |
+
|
156 |
+
def resnet101_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
157 |
+
return _clip("resnet101", input_size)
|
158 |
+
|
159 |
+
def resnet50x4_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
160 |
+
return _clip("resnet50x4", input_size)
|
161 |
+
|
162 |
+
def resnet50x16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
163 |
+
return _clip("resnet50x16", input_size)
|
164 |
+
|
165 |
+
def resnet50x64_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
166 |
+
return _clip("resnet50x64", input_size)
|
167 |
+
|
168 |
+
def vit_b_32_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
169 |
+
return _clip("vit_b_32", input_size)
|
170 |
+
|
171 |
+
def vit_b_16_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
172 |
+
return _clip("vit_b_16", input_size)
|
173 |
+
|
174 |
+
def vit_l_14_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
175 |
+
return _clip("vit_l_14", input_size)
|
176 |
+
|
177 |
+
def vit_l_14_336px_clip(input_size: Optional[Union[int, Tuple[int, int]]] = None) -> CLIP:
|
178 |
+
return _clip("vit_l_14_336px", input_size)
|
179 |
+
|
180 |
+
|
181 |
+
# CLIP image encoders
|
182 |
+
def resnet50_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
|
183 |
+
return _resnet("resnet50", features_only=features_only, out_indices=out_indices, **kwargs)
|
184 |
+
|
185 |
+
def resnet101_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
|
186 |
+
return _resnet("resnet101", features_only=features_only, out_indices=out_indices, **kwargs)
|
187 |
+
|
188 |
+
def resnet50x4_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
|
189 |
+
return _resnet("resnet50x4", features_only=features_only, out_indices=out_indices, **kwargs)
|
190 |
+
|
191 |
+
def resnet50x16_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
|
192 |
+
return _resnet("resnet50x16", features_only=features_only, out_indices=out_indices, **kwargs)
|
193 |
+
|
194 |
+
def resnet50x64_img(features_only: bool = False, out_indices: Optional[Tuple[int, ...]] = None, **kwargs: Any) -> ModifiedResNet:
|
195 |
+
return _resnet("resnet50x64", features_only=features_only, out_indices=out_indices, **kwargs)
|
196 |
+
|
197 |
+
def vit_b_32_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
|
198 |
+
return _vit("vit_b_32", features_only=features_only, input_size=input_size, **kwargs)
|
199 |
+
|
200 |
+
def vit_b_16_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
|
201 |
+
return _vit("vit_b_16", features_only=features_only, input_size=input_size, **kwargs)
|
202 |
+
|
203 |
+
def vit_l_14_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
|
204 |
+
return _vit("vit_l_14", features_only=features_only, input_size=input_size, **kwargs)
|
205 |
+
|
206 |
+
def vit_l_14_336px_img(features_only: bool = False, input_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs: Any) -> VisionTransformer:
|
207 |
+
return _vit("vit_l_14_336px", features_only=features_only, input_size=input_size, **kwargs)
|
208 |
+
|
209 |
+
|
210 |
+
# CLIP text encoders
|
211 |
+
def resnet50_txt() -> CLIPTextEncoder:
|
212 |
+
return _text_encoder("resnet50")
|
213 |
+
|
214 |
+
def resnet101_txt() -> CLIPTextEncoder:
|
215 |
+
return _text_encoder("resnet101")
|
216 |
+
|
217 |
+
def resnet50x4_txt() -> CLIPTextEncoder:
|
218 |
+
return _text_encoder("resnet50x4")
|
219 |
+
|
220 |
+
def resnet50x16_txt() -> CLIPTextEncoder:
|
221 |
+
return _text_encoder("resnet50x16")
|
222 |
+
|
223 |
+
def resnet50x64_txt() -> CLIPTextEncoder:
|
224 |
+
return _text_encoder("resnet50x64")
|
225 |
+
|
226 |
+
def vit_b_32_txt() -> CLIPTextEncoder:
|
227 |
+
return _text_encoder("vit_b_32")
|
228 |
+
|
229 |
+
def vit_b_16_txt() -> CLIPTextEncoder:
|
230 |
+
return _text_encoder("vit_b_16")
|
231 |
+
|
232 |
+
def vit_l_14_txt() -> CLIPTextEncoder:
|
233 |
+
return _text_encoder("vit_l_14")
|
234 |
+
|
235 |
+
def vit_l_14_336px_txt() -> CLIPTextEncoder:
|
236 |
+
return _text_encoder("vit_l_14_336px")
|
237 |
+
|
238 |
+
|
239 |
+
__all__ = [
|
240 |
+
# utils
|
241 |
+
"tokenize",
|
242 |
+
"transform",
|
243 |
+
# clip models
|
244 |
+
"resnet50_clip",
|
245 |
+
"resnet101_clip",
|
246 |
+
"resnet50x4_clip",
|
247 |
+
"resnet50x16_clip",
|
248 |
+
"resnet50x64_clip",
|
249 |
+
"vit_b_32_clip",
|
250 |
+
"vit_b_16_clip",
|
251 |
+
"vit_l_14_clip",
|
252 |
+
"vit_l_14_336px_clip",
|
253 |
+
# clip image encoders
|
254 |
+
"resnet50_img",
|
255 |
+
"resnet101_img",
|
256 |
+
"resnet50x4_img",
|
257 |
+
"resnet50x16_img",
|
258 |
+
"resnet50x64_img",
|
259 |
+
"vit_b_32_img",
|
260 |
+
"vit_b_16_img",
|
261 |
+
"vit_l_14_img",
|
262 |
+
"vit_l_14_336px_img",
|
263 |
+
# clip text encoders
|
264 |
+
"resnet50_txt",
|
265 |
+
"resnet101_txt",
|
266 |
+
"resnet50x4_txt",
|
267 |
+
"resnet50x16_txt",
|
268 |
+
"resnet50x64_txt",
|
269 |
+
"vit_b_32_txt",
|
270 |
+
"vit_b_16_txt",
|
271 |
+
"vit_l_14_txt",
|
272 |
+
"vit_l_14_336px_txt",
|
273 |
+
]
|
models/clip/_clip/blocks.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from collections import OrderedDict
|
5 |
+
from typing import Optional, Iterable
|
6 |
+
|
7 |
+
|
8 |
+
class LayerNorm(nn.LayerNorm):
|
9 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
10 |
+
|
11 |
+
def forward(self, x: Tensor):
|
12 |
+
orig_type = x.dtype
|
13 |
+
ret = super().forward(x.type(torch.float32))
|
14 |
+
return ret.type(orig_type)
|
15 |
+
|
16 |
+
|
17 |
+
class QuickGELU(nn.Module):
|
18 |
+
def forward(self, x: Tensor):
|
19 |
+
return x * torch.sigmoid(1.702 * x)
|
20 |
+
|
21 |
+
|
22 |
+
class ResidualAttentionBlock(nn.Module):
|
23 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: Tensor = None):
|
24 |
+
super().__init__()
|
25 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
26 |
+
self.ln_1 = LayerNorm(d_model)
|
27 |
+
self.mlp = nn.Sequential(OrderedDict([
|
28 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
29 |
+
("gelu", QuickGELU()),
|
30 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
31 |
+
]))
|
32 |
+
self.ln_2 = LayerNorm(d_model)
|
33 |
+
self.attn_mask = attn_mask
|
34 |
+
|
35 |
+
def attention(self, x: Tensor):
|
36 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
37 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
38 |
+
|
39 |
+
def forward(self, x: Tensor) -> Tensor:
|
40 |
+
x = x + self.attention(self.ln_1(x))
|
41 |
+
x = x + self.mlp(self.ln_2(x))
|
42 |
+
return x
|
43 |
+
|
44 |
+
|
45 |
+
class Transformer(nn.Module):
|
46 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: Tensor = None):
|
47 |
+
super().__init__()
|
48 |
+
self.width = width
|
49 |
+
self.layers = layers
|
50 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
51 |
+
|
52 |
+
def forward(self, x: Tensor):
|
53 |
+
return self.resblocks(x)
|
54 |
+
|
55 |
+
|
56 |
+
class Bottleneck(nn.Module):
|
57 |
+
expansion = 4
|
58 |
+
|
59 |
+
def __init__(self, inplanes, planes, stride=1):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
63 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
64 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
65 |
+
self.relu1 = nn.ReLU(inplace=True)
|
66 |
+
|
67 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
68 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
69 |
+
self.relu2 = nn.ReLU(inplace=True)
|
70 |
+
|
71 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
72 |
+
|
73 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
74 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
75 |
+
self.relu3 = nn.ReLU(inplace=True)
|
76 |
+
|
77 |
+
self.downsample = None
|
78 |
+
self.stride = stride
|
79 |
+
|
80 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
81 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
82 |
+
self.downsample = nn.Sequential(OrderedDict([
|
83 |
+
("-1", nn.AvgPool2d(stride)),
|
84 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
85 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
86 |
+
]))
|
87 |
+
|
88 |
+
def forward(self, x: Tensor):
|
89 |
+
identity = x
|
90 |
+
|
91 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
92 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
93 |
+
out = self.avgpool(out)
|
94 |
+
out = self.bn3(self.conv3(out))
|
95 |
+
|
96 |
+
if self.downsample is not None:
|
97 |
+
identity = self.downsample(x)
|
98 |
+
|
99 |
+
out += identity
|
100 |
+
out = self.relu3(out)
|
101 |
+
return out
|
102 |
+
|
103 |
+
|
104 |
+
class AttentionPool2d(nn.Module):
|
105 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
106 |
+
super().__init__()
|
107 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
108 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
109 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
110 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
111 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
112 |
+
self.num_heads = num_heads
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
116 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
117 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
118 |
+
x, _ = F.multi_head_attention_forward(
|
119 |
+
query=x[:1], key=x, value=x,
|
120 |
+
embed_dim_to_check=x.shape[-1],
|
121 |
+
num_heads=self.num_heads,
|
122 |
+
q_proj_weight=self.q_proj.weight,
|
123 |
+
k_proj_weight=self.k_proj.weight,
|
124 |
+
v_proj_weight=self.v_proj.weight,
|
125 |
+
in_proj_weight=None,
|
126 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
127 |
+
bias_k=None,
|
128 |
+
bias_v=None,
|
129 |
+
add_zero_attn=False,
|
130 |
+
dropout_p=0,
|
131 |
+
out_proj_weight=self.c_proj.weight,
|
132 |
+
out_proj_bias=self.c_proj.bias,
|
133 |
+
use_separate_proj_weight=True,
|
134 |
+
training=self.training,
|
135 |
+
need_weights=False
|
136 |
+
)
|
137 |
+
return x.squeeze(0)
|
models/clip/_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
models/clip/_clip/configs/clip_image_encoder_resnet101.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": [
|
5 |
+
3,
|
6 |
+
4,
|
7 |
+
23,
|
8 |
+
3
|
9 |
+
],
|
10 |
+
"vision_width": 64,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"vision_heads": 32
|
13 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_resnet50.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": [
|
5 |
+
3,
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
3
|
9 |
+
],
|
10 |
+
"vision_width": 64,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"vision_heads": 32
|
13 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_resnet50x16.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"image_resolution": 384,
|
4 |
+
"vision_layers": [
|
5 |
+
6,
|
6 |
+
8,
|
7 |
+
18,
|
8 |
+
8
|
9 |
+
],
|
10 |
+
"vision_width": 96,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"vision_heads": 48
|
13 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_resnet50x4.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"image_resolution": 288,
|
4 |
+
"vision_layers": [
|
5 |
+
4,
|
6 |
+
6,
|
7 |
+
10,
|
8 |
+
6
|
9 |
+
],
|
10 |
+
"vision_width": 80,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"vision_heads": 40
|
13 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_resnet50x64.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"image_resolution": 448,
|
4 |
+
"vision_layers": [
|
5 |
+
3,
|
6 |
+
15,
|
7 |
+
36,
|
8 |
+
10
|
9 |
+
],
|
10 |
+
"vision_width": 128,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"vision_heads": 64
|
13 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_vit_b_16.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 12,
|
5 |
+
"vision_width": 768,
|
6 |
+
"vision_patch_size": 16,
|
7 |
+
"vision_heads": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_vit_b_32.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 12,
|
5 |
+
"vision_width": 768,
|
6 |
+
"vision_patch_size": 32,
|
7 |
+
"vision_heads": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_vit_l_14.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 24,
|
5 |
+
"vision_width": 1024,
|
6 |
+
"vision_patch_size": 14,
|
7 |
+
"vision_heads": 16
|
8 |
+
}
|
models/clip/_clip/configs/clip_image_encoder_vit_l_14_336px.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"image_resolution": 336,
|
4 |
+
"vision_layers": 24,
|
5 |
+
"vision_width": 1024,
|
6 |
+
"vision_patch_size": 14,
|
7 |
+
"vision_heads": 16
|
8 |
+
}
|
models/clip/_clip/configs/clip_resnet101.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": [
|
5 |
+
3,
|
6 |
+
4,
|
7 |
+
23,
|
8 |
+
3
|
9 |
+
],
|
10 |
+
"vision_width": 64,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"context_length": 77,
|
13 |
+
"vocab_size": 49408,
|
14 |
+
"transformer_width": 512,
|
15 |
+
"transformer_heads": 8,
|
16 |
+
"transformer_layers": 12
|
17 |
+
}
|
models/clip/_clip/configs/clip_resnet50.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": [
|
5 |
+
3,
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
3
|
9 |
+
],
|
10 |
+
"vision_width": 64,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"context_length": 77,
|
13 |
+
"vocab_size": 49408,
|
14 |
+
"transformer_width": 512,
|
15 |
+
"transformer_heads": 8,
|
16 |
+
"transformer_layers": 12
|
17 |
+
}
|
models/clip/_clip/configs/clip_resnet50x16.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"image_resolution": 384,
|
4 |
+
"vision_layers": [
|
5 |
+
6,
|
6 |
+
8,
|
7 |
+
18,
|
8 |
+
8
|
9 |
+
],
|
10 |
+
"vision_width": 96,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"context_length": 77,
|
13 |
+
"vocab_size": 49408,
|
14 |
+
"transformer_width": 768,
|
15 |
+
"transformer_heads": 12,
|
16 |
+
"transformer_layers": 12
|
17 |
+
}
|
models/clip/_clip/configs/clip_resnet50x4.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"image_resolution": 288,
|
4 |
+
"vision_layers": [
|
5 |
+
4,
|
6 |
+
6,
|
7 |
+
10,
|
8 |
+
6
|
9 |
+
],
|
10 |
+
"vision_width": 80,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"context_length": 77,
|
13 |
+
"vocab_size": 49408,
|
14 |
+
"transformer_width": 640,
|
15 |
+
"transformer_heads": 10,
|
16 |
+
"transformer_layers": 12
|
17 |
+
}
|
models/clip/_clip/configs/clip_resnet50x64.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"image_resolution": 448,
|
4 |
+
"vision_layers": [
|
5 |
+
3,
|
6 |
+
15,
|
7 |
+
36,
|
8 |
+
10
|
9 |
+
],
|
10 |
+
"vision_width": 128,
|
11 |
+
"vision_patch_size": null,
|
12 |
+
"context_length": 77,
|
13 |
+
"vocab_size": 49408,
|
14 |
+
"transformer_width": 1024,
|
15 |
+
"transformer_heads": 16,
|
16 |
+
"transformer_layers": 12
|
17 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_resnet101.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 512,
|
6 |
+
"transformer_heads": 8,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_resnet50.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 512,
|
6 |
+
"transformer_heads": 8,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_resnet50x16.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 768,
|
6 |
+
"transformer_heads": 12,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_resnet50x4.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 640,
|
6 |
+
"transformer_heads": 10,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_resnet50x64.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 1024,
|
6 |
+
"transformer_heads": 16,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_vit_b_16.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 512,
|
6 |
+
"transformer_heads": 8,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_vit_b_32.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 512,
|
6 |
+
"transformer_heads": 8,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_vit_l_14.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 768,
|
6 |
+
"transformer_heads": 12,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_text_encoder_vit_l_14_336px.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"context_length": 77,
|
4 |
+
"vocab_size": 49408,
|
5 |
+
"transformer_width": 768,
|
6 |
+
"transformer_heads": 12,
|
7 |
+
"transformer_layers": 12
|
8 |
+
}
|
models/clip/_clip/configs/clip_vit_b_16.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 12,
|
5 |
+
"vision_width": 768,
|
6 |
+
"vision_patch_size": 16,
|
7 |
+
"context_length": 77,
|
8 |
+
"vocab_size": 49408,
|
9 |
+
"transformer_width": 512,
|
10 |
+
"transformer_heads": 8,
|
11 |
+
"transformer_layers": 12
|
12 |
+
}
|
models/clip/_clip/configs/clip_vit_b_32.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 12,
|
5 |
+
"vision_width": 768,
|
6 |
+
"vision_patch_size": 32,
|
7 |
+
"context_length": 77,
|
8 |
+
"vocab_size": 49408,
|
9 |
+
"transformer_width": 512,
|
10 |
+
"transformer_heads": 8,
|
11 |
+
"transformer_layers": 12
|
12 |
+
}
|