Commit a7dedf9 · 1 Parent(s): 2a347d3
Yiming-M committed on 2025-07-31 18:59 🐣
.gitignore ADDED
@@ -0,0 +1,163 @@
1
+ # MacOS
2
+ **/.DS_Store
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: ZIP B
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.39.0
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  short_description: The crowd counting model ZIP-B
12
  ---
 
1
  ---
2
+ title: ZIP
3
+ emoji: 🔢
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.39.0
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  short_description: The crowd counting model ZIP-B
12
  ---
app.py ADDED
@@ -0,0 +1,459 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms.functional as TF
5
+
6
+ from torch import Tensor
7
+ import spaces
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ import gradio as gr
12
+ from matplotlib import cm
13
+ from huggingface_hub import hf_hub_download
14
+ from warnings import warn
15
+
16
+ from models import get_model
17
+
18
+
19
+ mean = (0.485, 0.456, 0.406)
20
+ std = (0.229, 0.224, 0.225)
21
+ alpha = 0.8
22
+ EPS = 1e-8
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ loaded_model = None  # set by load_model(); checked in predict() before loading a checkpoint
24
+
25
+
26
+ pretrained_datasets = {
27
+ "ZIP-B": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF", "NWPU-Crowd"],
28
+ "ZIP-S": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
29
+ "ZIP-T": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
30
+ "ZIP-N": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
31
+ "ZIP-P": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
32
+ }
33
+
34
+ # -----------------------------
35
+ # Define the model architecture
36
+ # -----------------------------
37
+ def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
38
+ """ Load the model weights from the Hugging Face Hub."""
39
+ global loaded_model
40
+ # Build model
41
+
42
+ model_info_path = hf_hub_download(
43
+ repo_id=f"Yiming-M/{variant}",
44
+ filename=f"checkpoints/{dataset}/best_{metric}.pth",
45
+ )
46
+
47
+ model = get_model(model_info_path=model_info_path)
48
+ model.eval()
49
+ loaded_model = model
50
+
51
+
52
+ def _calc_size(
53
+ img_w: int,
54
+ img_h: int,
55
+ min_size: int,
56
+ max_size: int,
57
+ base: int = 32
58
+ ):
59
+ """
60
+ This function generates a new size for an image while keeping the aspect ratio. The new size should be within the given range (min_size, max_size).
61
+
62
+ Args:
63
+ img_w (int): The width of the image.
64
+ img_h (int): The height of the image.
65
+ min_size (int): The minimum size of the edges of the image.
66
+ max_size (int): The maximum size of the edges of the image.
67
+ base (int): The value that the new edge lengths should be a multiple of.
68
+ """
69
+ assert min_size % base == 0, f"min_size ({min_size}) must be a multiple of {base}"
70
+ if max_size != float("inf"):
71
+ assert max_size % base == 0, f"max_size ({max_size}) must be a multiple of {base} if provided"
72
+
73
+ assert min_size <= max_size, f"min_size ({min_size}) must be less than or equal to max_size ({max_size})"
74
+
75
+ aspect_ratios = (img_w / img_h, img_h / img_w)
76
+ if min_size / max_size <= min(aspect_ratios) <= max(aspect_ratios) <= max_size / min_size: # possible to resize and preserve the aspect ratio
77
+ if min_size <= min(img_w, img_h) <= max(img_w, img_h) <= max_size: # already within the range, no need to resize
78
+ ratio = 1.
79
+ elif min(img_w, img_h) < min_size: # smaller than the minimum size, resize to the minimum size
80
+ ratio = min_size / min(img_w, img_h)
81
+ else: # larger than the maximum size, resize to the maximum size
82
+ ratio = max_size / max(img_w, img_h)
83
+
84
+ new_w, new_h = int(round(img_w * ratio / base) * base), int(round(img_h * ratio / base) * base)
85
+ new_w = max(min_size, min(max_size, new_w))
86
+ new_h = max(min_size, min(max_size, new_h))
87
+ return new_w, new_h
88
+
89
+ else: # impossible to resize and preserve the aspect ratio
90
+ msg = f"Impossible to resize {img_w}x{img_h} image while preserving the aspect ratio to a size within the range ({min_size}, {max_size}). Will not limit the maximum size."
91
+ warn(msg)
92
+ return _calc_size(img_w, img_h, min_size, float("inf"), base)
93
+
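A quick illustration of _calc_size (added here for clarity; the input values are hypothetical and not part of the commit):

    # Both edges are rounded to multiples of `base` (32) while staying within
    # [min_size, max_size] and roughly preserving the aspect ratio.
    new_w, new_h = _calc_size(img_w=1920, img_h=1080, min_size=448, max_size=2048, base=32)
    print(new_w, new_h)  # -> 1920 1088 (already in range, edges snapped to multiples of 32)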
94
+
95
+ # -----------------------------
96
+ # Preprocessing function
97
+ # -----------------------------
98
+ # Adjust the image transforms to match what your model expects.
99
+ def transform(image: Image.Image, dataset_name: str) -> Tensor:
100
+ assert isinstance(image, Image.Image), "Input must be a PIL Image"
101
+ image_tensor = TF.to_tensor(image)
102
+
103
+ if dataset_name == "sha":
104
+ min_size = 448
105
+ max_size = float("inf")
106
+ elif dataset_name == "shb":
107
+ min_size = 448
108
+ max_size = float("inf")
109
+ elif dataset_name == "qnrf":
110
+ min_size = 448
111
+ max_size = 2048
112
+ elif dataset_name == "nwpu":
113
+ min_size = 448
114
+ max_size = 3072
115
+
116
+ image_height, image_width = image_tensor.shape[-2:]
117
+ new_width, new_height = _calc_size(
118
+ img_w=image_width,
119
+ img_h=image_height,
120
+ min_size=min_size,
121
+ max_size=max_size,
122
+ base=32
123
+ )
124
+ if new_height != image_height or new_width != image_width:
125
+ image_tensor = TF.resize(image_tensor, size=(new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)  # Lanczos is PIL-only; tensor inputs support bilinear/bicubic
126
+
127
+ image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
128
+ return image_tensor.unsqueeze(0) # Add batch dimension
129
+
130
+
131
+ def _sliding_window_predict(
132
+ model: nn.Module,
133
+ image: Tensor,
134
+ window_size: int,
135
+ stride: int,
136
+ max_num_windows: int = 256
137
+ ):
138
+ assert len(image.shape) == 4, f"Image must be a 4D tensor (1, c, h, w), got {image.shape}"
139
+ window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
140
+ stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
141
+ window_size = tuple(window_size)
142
+ stride = tuple(stride)
143
+ assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, f"Window size must be a positive integer tuple (h, w), got {window_size}"
144
+ assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, f"Stride must be a positive integer tuple (h, w), got {stride}"
145
+ assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"Stride must be smaller than window size, got {stride} and {window_size}"
146
+
147
+ image_height, image_width = image.shape[-2:]
148
+ window_height, window_width = window_size
149
+ assert image_height >= window_height and image_width >= window_width, f"Image size must be larger than window size, got image size {image.shape} and window size {window_size}"
150
+ stride_height, stride_width = stride
151
+
152
+ num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
153
+ num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)
154
+
155
+ if hasattr(model, "block_size"):
156
+ block_size = model.block_size
157
+ elif hasattr(model, "module") and hasattr(model.module, "block_size"):
158
+ block_size = model.module.block_size
159
+ else:
160
+ raise ValueError("Model must have block_size attribute")
161
+ assert window_height % block_size == 0 and window_width % block_size == 0, f"Window size must be divisible by block size, got {window_size} and {block_size}"
162
+
163
+ windows = []
164
+ for i in range(num_rows):
165
+ for j in range(num_cols):
166
+ x_start, y_start = i * stride_height, j * stride_width
167
+ x_end, y_end = x_start + window_height, y_start + window_width
168
+ if x_end > image_height:
169
+ x_start, x_end = image_height - window_height, image_height
170
+ if y_end > image_width:
171
+ y_start, y_end = image_width - window_width, image_width
172
+
173
+ window = image[:, :, x_start:x_end, y_start:y_end]
174
+ windows.append(window)
175
+
176
+ windows = torch.cat(windows, dim=0).to(image.device) # batched windows, shape: (num_windows, c, h, w)
177
+
178
+ model.eval()
179
+ pi_maps, lambda_maps = [], []
180
+ for i in range(0, len(windows), max_num_windows):
181
+ with torch.no_grad():
182
+ image_feats = model.backbone(windows[i: min(i + max_num_windows, len(windows))])
183
+ pi_image_feats, lambda_image_feats = model.pi_head(image_feats), model.lambda_head(image_feats)
184
+ pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
185
+ lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
186
+
187
+ pi_text_feats, lambda_text_feats = model.pi_text_feats, model.lambda_text_feats
188
+ pi_logit_scale, lambda_logit_scale = model.pi_logit_scale.exp(), model.lambda_logit_scale.exp()
189
+
190
+ pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t() # (B, H, W, 2), logits per image
191
+ lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t() # (B, H, W, N - 1), logits per image
192
+
193
+ pi_logit_map = pi_logit_map.permute(0, 3, 1, 2) # (B, 2, H, W)
194
+ lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2) # (B, N - 1, H, W)
195
+
196
+ lambda_map = (lambda_logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # (B, 1, H, W)
197
+ pi_map = pi_logit_map.softmax(dim=1)[:, 0:1] # (B, 1, H, W)
198
+
199
+ pi_maps.append(pi_map.cpu().numpy())
200
+ lambda_maps.append(lambda_map.cpu().numpy())
201
+
202
+ # assemble the density map
203
+ pi_maps = np.concatenate(pi_maps, axis=0) # shape: (num_windows, 1, H, W)
204
+ lambda_maps = np.concatenate(lambda_maps, axis=0) # shape: (num_windows, 1, H, W)
205
+ assert pi_maps.shape == lambda_maps.shape, f"pi_maps and lambda_maps must have the same shape, got {pi_maps.shape} and {lambda_maps.shape}"
206
+
207
+ pi_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
208
+ lambda_map = np.zeros((lambda_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
209
+ count_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
210
+ idx = 0
211
+ for i in range(num_rows):
212
+ for j in range(num_cols):
213
+ x_start, y_start = i * stride_height, j * stride_width
214
+ x_end, y_end = x_start + window_height, y_start + window_width
215
+ if x_end > image_height:
216
+ x_start, x_end = image_height - window_height, image_height
217
+ if y_end > image_width:
218
+ y_start, y_end = image_width - window_width, image_width
219
+
220
+ pi_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += pi_maps[idx, :, :, :]
221
+ lambda_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += lambda_maps[idx, :, :, :]
222
+ count_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += 1.
223
+ idx += 1
224
+
225
+ # average the density map
226
+ pi_map /= count_map
227
+ lambda_map /= count_map
228
+
229
+ # convert to Tensor and reshape
230
+ pi_map = torch.from_numpy(pi_map).unsqueeze(0) # shape: (1, 1, H // block_size, W // block_size)
231
+ lambda_map = torch.from_numpy(lambda_map).unsqueeze(0) # shape: (1, 1, H // block_size, W // block_size)
232
+ return pi_map, lambda_map
233
+
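Note on the sliding-window assembly above (explanatory, not part of the commit): overlapping windows are summed into pi_map / lambda_map while count_map records how many windows cover each block, so the division yields a per-block mean. A sanity check one could add:

    # Every block must be covered by at least one window (guaranteed here because
    # stride <= window size and out-of-range windows are clamped to the image border),
    # otherwise the division below would produce NaNs.
    assert (count_map > 0).all()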
234
+
235
+ # -----------------------------
236
+ # Inference function
237
+ # -----------------------------
238
+ @spaces.GPU(duration=120)
239
+ def predict(image: Image.Image, variant: str, dataset: str, metric: str):
240
+ """
241
+ Given an input image, preprocess it, run the model to obtain a density map,
242
+ compute the total crowd count, and prepare the density map for display.
243
+ """
244
+ global loaded_model
245
+
246
+ if loaded_model is None:
247
+
248
+ if dataset == "ShanghaiTech A":
249
+ dataset_name = "sha"
250
+ elif dataset == "ShanghaiTech B":
251
+ dataset_name = "shb"
252
+ elif dataset == "UCF-QNRF":
253
+ dataset_name = "qnrf"
254
+ elif dataset == "NWPU-Crowd":
255
+ dataset_name = "nwpu"
256
+
257
+ # load_model() resolves the Hub repo and checkpoint path from (variant, dataset, metric)
258
+ load_model(variant, dataset_name, metric)
259
+
260
+ loaded_model.to(device)
261
+
262
+ # Preprocess the image
263
+ input_width, input_height = image.size
264
+ image_tensor = transform(image, dataset_name).to(device) # shape: (1, 3, H, W)
265
+
266
+ input_size = loaded_model.input_size
267
+ image_height, image_width = image_tensor.shape[-2:]
268
+ aspect_ratio = image_width / image_height
269
+ if image_height < input_size:
270
+ new_height = input_size
271
+ new_width = int(new_height * aspect_ratio)
272
+ image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
273
+ image_height, image_width = new_height, new_width
274
+ if image_width < input_size:
275
+ new_width = input_size
276
+ new_height = int(new_width / aspect_ratio)
277
+ image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
278
+ image_height, image_width = new_height, new_width
279
+
280
+ with torch.no_grad():
281
+ if hasattr(loaded_model, "num_vpt") and loaded_model.num_vpt > 0: # For ViT models, use sliding window prediction
282
+ # For ViT models with VPT
283
+ pi_map, lambda_map = _sliding_window_predict(
284
+ model=loaded_model,
285
+ image=image_tensor,
286
+ window_size=input_size,
287
+ stride=input_size
288
+ )
289
+
290
+ elif hasattr(loaded_model, "pi_text_feats") and hasattr(loaded_model, "lambda_text_feats") and loaded_model.pi_text_feats is not None and loaded_model.lambda_text_feats is not None: # For other CLIP-based models
291
+ image_feats = loaded_model.backbone(image_tensor)
292
+ # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
293
+ pi_image_feats, lambda_image_feats = loaded_model.pi_head(image_feats), loaded_model.lambda_head(image_feats)
294
+ pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
295
+ lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
296
+
297
+ pi_text_feats, lambda_text_feats = loaded_model.pi_text_feats, loaded_model.lambda_text_feats
298
+ pi_logit_scale, lambda_logit_scale = loaded_model.pi_logit_scale.exp(), loaded_model.lambda_logit_scale.exp()
299
+
300
+ pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t() # (B, H, W, 2), logits per image
301
+ lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t() # (B, H, W, N - 1), logits per image
302
+
303
+ pi_logit_map = pi_logit_map.permute(0, 3, 1, 2) # (B, 2, H, W)
304
+ lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2) # (B, N - 1, H, W)
305
+
306
+ lambda_map = (lambda_logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # (B, 1, H, W)
307
+ pi_map = pi_logit_map.softmax(dim=1)[:, 0:1] # (B, 1, H, W)
308
+
309
+ else: # For non-CLIP models
310
+ x = loaded_model.backbone(image_tensor)
311
+ logit_pi_map = loaded_model.pi_head(x) # shape: (B, 2, H, W)
312
+ logit_map = loaded_model.bin_head(x) # shape: (B, C, H, W)
313
+ lambda_map= (logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # shape: (B, 1, H, W)
314
+ pi_map = logit_pi_map.softmax(dim=1)[:, 0:1] # shape: (B, 1, H, W)
315
+
316
+
317
+ den_map = (1.0 - pi_map) * lambda_map # shape: (B, 1, H, W)
318
+ count = den_map.sum().item()
319
+
320
+ structural_zero_map = F.interpolate(
321
+ pi_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
322
+ ).cpu().squeeze().numpy()
323
+
324
+ lambda_map = F.interpolate(
325
+ lambda_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
326
+ ).cpu().squeeze().numpy()
327
+
328
+ den_map = F.interpolate(
329
+ den_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
330
+ ).cpu().squeeze().numpy()
331
+
332
+ sampling_zero_map = (1.0 - structural_zero_map) * np.exp(-lambda_map)
333
+ complete_zero_map = structural_zero_map + sampling_zero_map
334
+
335
+ # Normalize maps for display purposes
336
+ def normalize_map(x: np.ndarray) -> np.ndarray:
337
+ """ Normalize the map to [0, 1] range for visualization. """
338
+ x_min = np.min(x)
339
+ x_max = np.max(x)
340
+ if x_max - x_min < EPS:
341
+ return np.zeros_like(x)
342
+ return (x - x_min) / (x_max - x_min + EPS)
343
+
344
+ structural_zero_map = normalize_map(structural_zero_map)
345
+ sampling_zero_map = normalize_map(sampling_zero_map)
346
+ lambda_map = normalize_map(lambda_map)
347
+ den_map = normalize_map(den_map)
348
+ complete_zero_map = normalize_map(complete_zero_map)
349
+
350
+ # Apply a colormap (e.g., 'jet') to get an RGBA image
351
+ colormap = cm.get_cmap("jet")
352
+
353
+ # The colormap returns values in [0,1]. Scale to [0,255] and convert to uint8.
354
+ den_map = (colormap(den_map) * 255).astype(np.uint8)
355
+ structural_zero_map = (colormap(structural_zero_map) * 255).astype(np.uint8)
356
+ sampling_zero_map = (colormap(sampling_zero_map) * 255).astype(np.uint8)
357
+ lambda_map = (colormap(lambda_map) * 255).astype(np.uint8)
358
+ complete_zero_map = (colormap(complete_zero_map) * 255).astype(np.uint8)
359
+
360
+ # Convert to PIL images
361
+ den_map = Image.fromarray(den_map).convert("RGBA")
362
+ structural_zero_map = Image.fromarray(structural_zero_map).convert("RGBA")
363
+ sampling_zero_map = Image.fromarray(sampling_zero_map).convert("RGBA")
364
+ lambda_map = Image.fromarray(lambda_map).convert("RGBA")
365
+ complete_zero_map = Image.fromarray(complete_zero_map).convert("RGBA")
366
+
367
+ # Ensure the original image is in RGBA format.
368
+ image_rgba = image.convert("RGBA")
369
+
370
+ den_map = Image.blend(image_rgba, den_map, alpha=alpha)
371
+ structural_zero_map = Image.blend(image_rgba, structural_zero_map, alpha=alpha)
372
+ sampling_zero_map = Image.blend(image_rgba, sampling_zero_map, alpha=alpha)
373
+ lambda_map = Image.blend(image_rgba, lambda_map, alpha=alpha)
374
+ complete_zero_map = Image.blend(image_rgba, complete_zero_map, alpha=alpha)
375
+
376
+ return image, structural_zero_map, sampling_zero_map, complete_zero_map, lambda_map, den_map, f"Predicted Count: {count:.2f}"
377
+
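Background on the maps returned above (explanatory note, not part of the commit): under a zero-inflated Poisson model, each block is empty with structural probability pi and otherwise draws a Poisson(lambda) count, so:

    # Per block, for scalar pi and lam (hypothetical names used only for this note):
    expected_count   = (1 - pi) * lam                  # what den_map accumulates before summing
    p_block_is_empty = pi + (1 - pi) * np.exp(-lam)    # structural + sampling zeros, i.e. complete_zero_map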
378
+
379
+ # -----------------------------
380
+ # Build Gradio Interface using Blocks for a two-column layout
381
+ # -----------------------------
382
+ with gr.Blocks() as demo:
383
+ gr.Markdown("# Crowd Counting by ZIP")
384
+ gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
385
+
386
+ with gr.Row():
387
+ with gr.Column():
388
+ # Dropdown for model variant
389
+ variant_dropdown = gr.Dropdown(
390
+ choices=list(pretrained_datasets.keys()),
391
+ value="ZIP-B",
392
+ label="Select Model Variant"
393
+ )
394
+
395
+ # Dropdown for pretrained dataset, dynamically updated based on variant
396
+ dataset_dropdown = gr.Dropdown(
397
+ choices=pretrained_datasets["ZIP-B"],
398
+ value=pretrained_datasets["ZIP-B"][0],
399
+ label="Select Pretrained Dataset"
400
+ )
401
+
402
+ # Dropdown for metric, always the same choices
403
+ metric_dropdown = gr.Dropdown(
404
+ choices=["mae", "rmse", "nae"],
405
+ value="mae",
406
+ label="Select Best Metric"
407
+ )
408
+
409
+ # Update dataset choices when variant changes
410
+ def update_dataset(variant):
411
+ choices = pretrained_datasets[variant]
412
+ return gr.Dropdown(
413
+ choices=choices,
414
+ value=choices[0]
415
+ )
416
+
417
+ variant_dropdown.change(
418
+ fn=update_dataset,
419
+ inputs=variant_dropdown,
420
+ outputs=dataset_dropdown
421
+ )
422
+ input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
423
+ submit_btn = gr.Button("Predict")
424
+
425
+ with gr.Column():
426
+ output_den_map = gr.Image(label="Predicted Density Map", type="pil")
427
+ output_structural_zero_map = gr.Image(label="Structural Zero Map", type="pil")
428
+ output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
429
+ output_lambda_map = gr.Image(label="Lambda Map", type="pil")
430
+ output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
431
+
432
+ output_text = gr.Textbox(label="Total Count")
433
+
434
+ submit_btn.click(
435
+ fn=predict,
436
+ inputs=[input_img, variant_dropdown, dataset_dropdown, metric_dropdown],
437
+ outputs=[input_img, output_structural_zero_map, output_sampling_zero_map, output_complete_zero_map, output_lambda_map, output_den_map, output_text]
438
+ )
439
+
440
+ gr.Examples(
441
+ examples=[
442
+ ["example1.jpg"],
443
+ ["example2.jpg"],
444
+ ["example3.jpg"],
445
+ ["example4.jpg"],
446
+ ["example5.jpg"],
447
+ ["example6.jpg"],
448
+ ["example7.jpg"],
449
+ ["example8.jpg"],
450
+ ["example9.jpg"],
451
+ ["example10.jpg"],
452
+ ["example11.jpg"],
453
+ ["example12.jpg"]
454
+ ],
455
+ inputs=input_img,
456
+ label="Try an example"
457
+ )
458
+
459
+ demo.launch()
models/__init__.py ADDED
@@ -0,0 +1,155 @@
1
+ import os, torch
2
+ from typing import List, Tuple, Optional, Union, Dict
3
+
4
+ from .ebc import _ebc, EBC
5
+ from .clip_ebc import _clip_ebc, CLIP_EBC
6
+
7
+
8
+ def get_model(
9
+ model_info_path: str,
10
+ model_name: Optional[str] = None,
11
+ block_size: Optional[int] = None,
12
+ bins: Optional[List[Tuple[float, float]]] = None,
13
+ bin_centers: Optional[List[float]] = None,
14
+ zero_inflated: Optional[bool] = True,
15
+ # parameters for CLIP_EBC
16
+ clip_weight_name: Optional[str] = None,
17
+ num_vpt: Optional[int] = None,
18
+ vpt_drop: Optional[float] = None,
19
+ input_size: Optional[int] = None,
20
+ adapter: bool = False,
21
+ adapter_reduction: Optional[int] = None,
22
+ lora: bool = False,
23
+ lora_rank: Optional[int] = None,
24
+ lora_alpha: Optional[int] = None,
25
+ lora_dropout: Optional[float] = None,
26
+ norm: str = "none",
27
+ act: str = "none",
28
+ text_prompts: Optional[List[str]] = None
29
+ ) -> Union[EBC, CLIP_EBC]:
30
+ if os.path.exists(model_info_path):
31
+ model_info = torch.load(model_info_path, map_location="cpu", weights_only=False)
32
+
33
+ model_name = model_info["config"]["model_name"]
34
+ block_size = model_info["config"]["block_size"]
35
+ bins = model_info["config"]["bins"]
36
+ bin_centers = model_info["config"]["bin_centers"]
37
+ zero_inflated = model_info["config"]["zero_inflated"]
38
+
39
+ clip_weight_name = model_info["config"].get("clip_weight_name", None)
40
+
41
+ num_vpt = model_info["config"].get("num_vpt", None)
42
+ vpt_drop = model_info["config"].get("vpt_drop", None)
43
+
44
+
45
+ adapter = model_info["config"].get("adapter", False)
46
+ adapter_reduction = model_info["config"].get("adapter_reduction", None)
47
+
48
+ lora = model_info["config"].get("lora", False)
49
+ lora_rank = model_info["config"].get("lora_rank", None)
50
+ lora_alpha = model_info["config"].get("lora_alpha", None)
51
+ lora_dropout = model_info["config"].get("lora_dropout", None)
52
+
53
+ input_size = model_info["config"].get("input_size", None)
54
+ text_prompts = model_info["config"].get("text_prompts", None)
55
+
56
+ norm = model_info["config"].get("norm", "none")
57
+ act = model_info["config"].get("act", "none")
58
+
59
+ weights = model_info["weights"]
60
+
61
+ else:
62
+ assert model_name is not None, "model_name should be provided if model_info_path is not provided"
63
+ assert block_size is not None, "block_size should be provided"
64
+ assert bins is not None, "bins should be provided"
65
+ assert bin_centers is not None, "bin_centers should be provided"
66
+ weights = None
67
+
68
+ if "ViT" in model_name:
69
+ assert num_vpt is not None, f"num_vpt should be provided for ViT models, got {num_vpt}"
70
+ assert vpt_drop is not None, f"vpt_drop should be provided for ViT models, got {vpt_drop}"
71
+
72
+ if model_name.startswith("CLIP_") or model_name.startswith("CLIP-"):
73
+ assert clip_weight_name is not None, f"clip_weight_name should be provided for CLIP models, got {clip_weight_name}"
74
+ model = _clip_ebc(
75
+ model_name=model_name[5:],
76
+ weight_name=clip_weight_name,
77
+ block_size=block_size,
78
+ bins=bins,
79
+ bin_centers=bin_centers,
80
+ zero_inflated=zero_inflated,
81
+ num_vpt=num_vpt,
82
+ vpt_drop=vpt_drop,
83
+ input_size=input_size,
84
+ adapter=adapter,
85
+ adapter_reduction=adapter_reduction,
86
+ lora=lora,
87
+ lora_rank=lora_rank,
88
+ lora_alpha=lora_alpha,
89
+ lora_dropout=lora_dropout,
90
+ text_prompts=text_prompts,
91
+ norm=norm,
92
+ act=act
93
+ )
94
+ model_config = {
95
+ "model_name": model_name,
96
+ "block_size": block_size,
97
+ "bins": bins,
98
+ "bin_centers": bin_centers,
99
+ "zero_inflated": zero_inflated,
100
+ "clip_weight_name": clip_weight_name,
101
+ "num_vpt": num_vpt,
102
+ "vpt_drop": vpt_drop,
103
+ "input_size": input_size,
104
+ "adapter": adapter,
105
+ "adapter_reduction": adapter_reduction,
106
+ "lora": lora,
107
+ "lora_rank": lora_rank,
108
+ "lora_alpha": lora_alpha,
109
+ "lora_dropout": lora_dropout,
110
+ "text_prompts": model.text_prompts,
111
+ "norm": norm,
112
+ "act": act
113
+ }
114
+
115
+ else:
116
+ assert not adapter, "adapter for non-CLIP models is not implemented yet"
117
+ assert not lora, "lora for non-CLIP models is not implemented yet"
118
+ model = _ebc(
119
+ model_name=model_name,
120
+ block_size=block_size,
121
+ bins=bins,
122
+ bin_centers=bin_centers,
123
+ zero_inflated=zero_inflated,
124
+ num_vpt=num_vpt,
125
+ vpt_drop=vpt_drop,
126
+ input_size=input_size,
127
+ norm=norm,
128
+ act=act
129
+ )
130
+ model_config = {
131
+ "model_name": model_name,
132
+ "block_size": block_size,
133
+ "bins": bins,
134
+ "bin_centers": bin_centers,
135
+ "zero_inflated": zero_inflated,
136
+ "num_vpt": num_vpt,
137
+ "vpt_drop": vpt_drop,
138
+ "input_size": input_size,
139
+ "norm": norm,
140
+ "act": act
141
+ }
142
+
143
+ model.config = model_config
144
+ model_info = {"config": model_config, "weights": weights}
145
+
146
+ if weights is not None:
147
+ model.load_state_dict(weights)
148
+
149
+ if not os.path.exists(model_info_path):
150
+ torch.save(model_info, model_info_path)
151
+
152
+ return model
153
+
154
+
155
+ __all__ = ["get_model"]
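A minimal usage sketch for get_model (assumes the checkpoint layout used by app.py's load_model above; illustrative only, not part of the commit):

    from huggingface_hub import hf_hub_download
    from models import get_model

    # Repo id and filename follow the pattern in app.py: Yiming-M/{variant}, checkpoints/{dataset}/best_{metric}.pth
    ckpt_path = hf_hub_download(repo_id="Yiming-M/ZIP-B", filename="checkpoints/ShanghaiTech B/best_mae.pth")
    model = get_model(model_info_path=ckpt_path)  # config and weights are read from the saved dict
    model.eval()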
models/clip_ebc/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .model import CLIP_EBC, _clip_ebc
2
+
3
+
4
+ __all__ = [
5
+ "CLIP_EBC",
6
+ "_clip_ebc",
7
+ ]
models/clip_ebc/convnext.py ADDED
@@ -0,0 +1,199 @@
1
+ from torch import nn, Tensor
2
+ import open_clip
3
+ from peft import get_peft_model, LoraConfig
4
+
5
+ from ..utils import ConvRefine, ConvAdapter
6
+ from ..utils import ConvUpsample, _get_norm_layer, _get_activation
7
+
8
+
9
+ convnext_names_and_weights = {
10
+ "convnext_base": ["laion400m_s13b_b51k"], # 107.49M
11
+ "convnext_base_w": ["laion2b_s13b_b82k", "laion2b_s13b_b82k_augreg", "laion_aesthetic_s13b_b82k"], # 107.75M
12
+ "convnext_base_w_320": ["laion_aesthetic_s13b_b82k", "laion_aesthetic_s13b_b82k_augreg"], # 107.75M
13
+ "convnext_large_d": ["laion2b_s26b_b102k_augreg"], # 217.46M
14
+ "convnext_large_d_320": ["laion2b_s29b_b131k_ft", "laion2b_s29b_b131k_ft_soup"], # 217.46M
15
+ "convnext_xxlarge": ["laion2b_s34b_b82k_augreg", "laion2b_s34b_b82k_augreg_rewind", "laion2b_s34b_b82k_augreg_soup"] # 896.88M
16
+ }
17
+
18
+ refiner_channels = {
19
+ "convnext_base": 1024,
20
+ "convnext_base_w": 1024,
21
+ "convnext_base_w_320": 1024,
22
+ "convnext_large_d": 1536,
23
+ "convnext_large_d_320": 1536,
24
+ "convnext_xxlarge": 3072,
25
+ }
26
+
27
+ refiner_groups = {
28
+ "convnext_base": 1,
29
+ "convnext_base_w": 1,
30
+ "convnext_base_w_320": 1,
31
+ "convnext_large_d": refiner_channels["convnext_large_d"] // 512, # 3
32
+ "convnext_large_d_320": refiner_channels["convnext_large_d_320"] // 512, # 3
33
+ "convnext_xxlarge": refiner_channels["convnext_xxlarge"] // 512, # 6
34
+ }
35
+
36
+
37
+
38
+ class ConvNeXt(nn.Module):
39
+ def __init__(
40
+ self,
41
+ model_name: str,
42
+ weight_name: str,
43
+ block_size: int = 16,
44
+ adapter: bool = False,
45
+ adapter_reduction: int = 4,
46
+ norm: str = "none",
47
+ act: str = "none"
48
+ ) -> None:
49
+ super(ConvNeXt, self).__init__()
50
+ assert model_name in convnext_names_and_weights, f"Model name should be one of {list(convnext_names_and_weights.keys())}, but got {model_name}."
51
+ assert weight_name in convnext_names_and_weights[model_name], f"Pretrained should be one of {convnext_names_and_weights[model_name]}, but got {weight_name}."
52
+ assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
53
+ self.model_name, self.weight_name = model_name, weight_name
54
+ self.block_size = block_size
55
+
56
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
57
+
58
+ self.adapter = adapter
59
+ if adapter:
60
+ self.adapter_reduction = adapter_reduction
61
+ for param in model.parameters():
62
+ param.requires_grad = False
63
+
64
+ self.stem = model.trunk.stem
65
+ self.depth = len(model.trunk.stages)
66
+ for idx, stage in enumerate(model.trunk.stages):
67
+ setattr(self, f"stage{idx}", stage)
68
+ if adapter:
69
+ setattr(self, f"adapter{idx}", ConvAdapter(
70
+ in_channels=stage.blocks[-1].mlp.fc2.out_features,
71
+ bottleneck_channels=stage.blocks[-1].mlp.fc2.out_features // adapter_reduction,
72
+ ) if idx < self.depth - 1 else nn.Identity()) # No adapter for the last stage
73
+
74
+ if self.model_name in ["convnext_base", "convnext_base_w", "convnext_base_w_320", "convnext_xxlarge"]:
75
+ self.in_features, self.out_features = model.head.proj.in_features, model.head.proj.out_features
76
+ else: # "convnext_large_d", "convnext_large_d_320":
77
+ self.in_features, self.out_features = model.head.mlp.fc1.in_features, model.head.mlp.fc2.out_features
78
+
79
+ if norm == "bn":
80
+ norm_layer = nn.BatchNorm2d
81
+ elif norm == "ln":
82
+ norm_layer = nn.LayerNorm
83
+ else:
84
+ norm_layer = _get_norm_layer(model)
85
+
86
+ if act == "relu":
87
+ activation = nn.ReLU(inplace=True)
88
+ elif act == "gelu":
89
+ activation = nn.GELU()
90
+ else:
91
+ activation = _get_activation(model)
92
+
93
+ if block_size == 32:
94
+ self.refiner = ConvRefine(
95
+ in_channels=self.in_features,
96
+ out_channels=self.in_features,
97
+ norm_layer=norm_layer,
98
+ activation=activation,
99
+ groups=refiner_groups[self.model_name],
100
+ )
101
+ elif block_size == 16:
102
+ self.refiner = ConvUpsample(
103
+ in_channels=self.in_features,
104
+ out_channels=self.in_features,
105
+ norm_layer=norm_layer,
106
+ activation=activation,
107
+ groups=refiner_groups[self.model_name],
108
+ )
109
+ else: # block_size == 8
110
+ self.refiner = nn.Sequential(
111
+ ConvUpsample(
112
+ in_channels=self.in_features,
113
+ out_channels=self.in_features,
114
+ norm_layer=norm_layer,
115
+ activation=activation,
116
+ groups=refiner_groups[self.model_name],
117
+ ),
118
+ ConvUpsample(
119
+ in_channels=self.in_features,
120
+ out_channels=self.in_features,
121
+ norm_layer=norm_layer,
122
+ activation=activation,
123
+ groups=refiner_groups[self.model_name],
124
+ ),
125
+ )
126
+
127
+ def train(self, mode: bool = True):
128
+ if self.adapter and mode:
129
+ # training:
130
+ self.stem.eval()
131
+
132
+ for idx in range(self.depth):
133
+ getattr(self, f"stage{idx}").eval()
134
+ getattr(self, f"adapter{idx}").train()
135
+
136
+ self.refiner.train()
137
+
138
+ else:
139
+ # evaluation:
140
+ for module in self.children():
141
+ module.train(mode)
142
+
143
+ def forward(self, x: Tensor) -> Tensor:
144
+ x = self.stem(x)
145
+
146
+ for idx in range(self.depth):
147
+ x = getattr(self, f"stage{idx}")(x)
148
+ if self.adapter:
149
+ x = getattr(self, f"adapter{idx}")(x)
150
+
151
+ x = self.refiner(x)
152
+ return x
153
+
154
+
155
+ def _convnext(
156
+ model_name: str,
157
+ weight_name: str,
158
+ block_size: int = 16,
159
+ adapter: bool = False,
160
+ adapter_reduction: int = 4,
161
+ lora: bool = False,
162
+ lora_rank: int = 16,
163
+ lora_alpha: float = 32.0,
164
+ lora_dropout: float = 0.1,
165
+ norm: str = "none",
166
+ act: str = "none"
167
+ ) -> ConvNeXt:
168
+ assert not (lora and adapter), "Lora and adapter cannot be used together."
169
+ model = ConvNeXt(
170
+ model_name=model_name,
171
+ weight_name=weight_name,
172
+ block_size=block_size,
173
+ adapter=adapter,
174
+ adapter_reduction=adapter_reduction,
175
+ norm=norm,
176
+ act=act
177
+ )
178
+
179
+ if lora:
180
+ target_modules = []
181
+ for name, module in model.named_modules():
182
+ if isinstance(module, (nn.Linear, nn.Conv2d)) and "refiner" not in name:
183
+ target_modules.append(name)
184
+
185
+ lora_config = LoraConfig(
186
+ r=lora_rank,
187
+ lora_alpha=lora_alpha,
188
+ lora_dropout=lora_dropout,
189
+ bias="none",
190
+ target_modules=target_modules,
191
+ )
192
+ model = get_peft_model(model, lora_config)
193
+
194
+ # Unfreeze refiner
195
+ for name, module in model.named_modules():
196
+ if "refiner" in name:
197
+ module.requires_grad_(True)
198
+
199
+ return model
models/clip_ebc/mobileclip.py ADDED
@@ -0,0 +1,197 @@
1
+ from torch import nn, Tensor
2
+ import open_clip
3
+ from peft import get_peft_model, LoraConfig
4
+
5
+ from ..utils import ConvRefine, ConvUpsample, ConvAdapter
6
+ from ..utils import _get_norm_layer, _get_activation
7
+
8
+
9
+ mobileclip_names_and_weights = {
10
+ "MobileCLIP-S1": ["datacompdr"],
11
+ "MobileCLIP-S2": ["datacompdr"],
12
+ }
13
+
14
+
15
+ refiner_channels = {
16
+ "MobileCLIP-S1": 1024,
17
+ "MobileCLIP-S2": 1280,
18
+ }
19
+
20
+ refiner_groups = {
21
+ "MobileCLIP-S1": 2,
22
+ "MobileCLIP-S2": 2,
23
+ }
24
+
25
+
26
+ class MobileCLIP(nn.Module):
27
+ def __init__(
28
+ self,
29
+ model_name: str,
30
+ weight_name: str,
31
+ block_size: int = 16,
32
+ adapter: bool = False,
33
+ adapter_reduction: int = 4,
34
+ norm: str = "none",
35
+ act: str = "none"
36
+ ) -> None:
37
+ super().__init__()
38
+ assert model_name in mobileclip_names_and_weights, f"Model name should be one of {list(mobileclip_names_and_weights.keys())}, but got {model_name}."
39
+ assert weight_name in mobileclip_names_and_weights[model_name], f"Pretrained should be one of {mobileclip_names_and_weights[model_name]}, but got {weight_name}."
40
+ assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
41
+ self.model_name, self.weight_name = model_name, weight_name
42
+ self.block_size = block_size
43
+
44
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
45
+
46
+ self.adapter = adapter
47
+ if adapter:
48
+ for param in model.parameters():
49
+ param.requires_grad = False
50
+
51
+ self.stem = model.trunk.stem
52
+ self.stages = model.trunk.stages
53
+
54
+ self.depth = len(model.trunk.stages)
55
+ for idx, stage in enumerate(model.trunk.stages):
56
+ if adapter:
57
+ setattr(self, f"adapter{idx}", ConvAdapter(
58
+ in_channels=stage.blocks[-1].mlp.fc2.out_channels,
59
+ bottleneck_channels=stage.blocks[-1].mlp.fc2.out_channels // adapter_reduction,
60
+ ))
61
+
62
+ self.final_conv = model.trunk.final_conv
63
+
64
+ self.in_features, self.out_features = model.trunk.head.fc.in_features, model.trunk.head.fc.out_features
65
+
66
+ # refine_block = LightConvRefine if model_name == "MobileCLIP-S1" else ConvRefine
67
+ # upsample_block = LightConvUpsample if model_name == "MobileCLIP-S1" else ConvUpsample
68
+
69
+ if norm == "bn":
70
+ norm_layer = nn.BatchNorm2d
71
+ elif norm == "ln":
72
+ norm_layer = nn.LayerNorm
73
+ else:
74
+ norm_layer = _get_norm_layer(model)
75
+
76
+ if act == "relu":
77
+ activation = nn.ReLU(inplace=True)
78
+ elif act == "gelu":
79
+ activation = nn.GELU()
80
+ else:
81
+ activation = _get_activation(model)
82
+
83
+ if block_size == 32:
84
+ self.refiner = ConvRefine(
85
+ in_channels=self.in_features,
86
+ out_channels=self.in_features,
87
+ norm_layer=norm_layer,
88
+ activation=activation,
89
+ groups=refiner_groups[model_name],
90
+ )
91
+ elif block_size == 16:
92
+ self.refiner = ConvUpsample(
93
+ in_channels=self.in_features,
94
+ out_channels=self.in_features,
95
+ norm_layer=norm_layer,
96
+ activation=activation,
97
+ groups=refiner_groups[self.model_name],
98
+ )
99
+ else: # block_size == 8
100
+ self.refiner = nn.Sequential(
101
+ ConvUpsample(
102
+ in_channels=self.in_features,
103
+ out_channels=self.in_features,
104
+ norm_layer=norm_layer,
105
+ activation=activation,
106
+ groups=refiner_groups[self.model_name],
107
+ ),
108
+ ConvUpsample(
109
+ in_channels=self.in_features,
110
+ out_channels=self.in_features,
111
+ norm_layer=norm_layer,
112
+ activation=activation,
113
+ groups=refiner_groups[self.model_name],
114
+ ),
115
+ )
116
+
117
+ def train(self, mode: bool = True):
118
+ if self.adapter and mode:
119
+ # training:
120
+ self.stem.eval()
121
+
122
+ for idx in range(self.depth):
123
+ self.stages[idx].eval()  # stages live in self.stages; this class does not set per-index "stage{idx}" attributes
124
+ getattr(self, f"adapter{idx}").train()
125
+
126
+ self.final_conv.eval()
127
+ self.refiner.train()
128
+
129
+ else:
130
+ # evaluation:
131
+ for module in self.children():
132
+ module.train(mode)
133
+
134
+ def forward(self, x: Tensor) -> Tensor:
135
+ x = self.stem(x)
136
+
137
+ for idx in range(self.depth):
138
+ x = self.stages[idx](x)
139
+ if self.adapter:
140
+ x = getattr(self, f"adapter{idx}")(x)
141
+
142
+ x = self.final_conv(x)
143
+
144
+ x = self.refiner(x)
145
+ return x
146
+
147
+
148
+ def _mobileclip(
149
+ model_name: str,
150
+ weight_name: str,
151
+ block_size: int = 16,
152
+ adapter: bool = False,
153
+ adapter_reduction: int = 4,
154
+ lora: bool = False,
155
+ lora_rank: int = 16,
156
+ lora_alpha: float = 32.0,
157
+ lora_dropout: float = 0.1,
158
+ norm: str = "none",
159
+ act: str = "none"
160
+ ) -> MobileCLIP:
161
+ assert not (lora and adapter), "Lora and adapter cannot be used together."
162
+ model = MobileCLIP(
163
+ model_name=model_name,
164
+ weight_name=weight_name,
165
+ block_size=block_size,
166
+ adapter=adapter,
167
+ adapter_reduction=adapter_reduction,
168
+ norm=norm,
169
+ act=act
170
+ )
171
+
172
+ if lora:
173
+ target_modules = []
174
+ for name, module in model.named_modules():
175
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
176
+ target_modules.append(name)
177
+
178
+ lora_config = LoraConfig(
179
+ r=lora_rank,
180
+ lora_alpha=lora_alpha,
181
+ lora_dropout=lora_dropout,
182
+ bias="none",
183
+ target_modules=target_modules,
184
+ )
185
+ model = get_peft_model(model, lora_config)
186
+
187
+ # Unfreeze the BN layers
188
+ for name, module in model.named_modules():
189
+ if isinstance(module, nn.BatchNorm2d) and "refiner" not in name:
190
+ module.requires_grad_(True)
191
+
192
+ # Unfreeze refiner
193
+ for name, module in model.named_modules():
194
+ if "refiner" in name:
195
+ module.requires_grad_(True)
196
+
197
+ return model
models/clip_ebc/model.py ADDED
@@ -0,0 +1,272 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from typing import List, Optional, Dict, Tuple
6
+ from copy import deepcopy
7
+
8
+ from .vit import vit_names_and_weights, _vit
9
+ from .convnext import convnext_names_and_weights, _convnext
10
+ from .resnet import resnet_names_and_weights, _resnet
11
+ from .mobileclip import mobileclip_names_and_weights, _mobileclip
12
+
13
+ from .utils import encode_text, optimize_text_prompts
14
+ from ..utils import conv1x1
15
+
16
+ supported_models_and_weights = deepcopy(vit_names_and_weights)
17
+ supported_models_and_weights.update(convnext_names_and_weights)
18
+ supported_models_and_weights.update(resnet_names_and_weights)
19
+ supported_models_and_weights.update(mobileclip_names_and_weights)
20
+
21
+
22
+ class CLIP_EBC(nn.Module):
23
+ def __init__(
24
+ self,
25
+ model_name: str,
26
+ weight_name: str,
27
+ block_size: Optional[int] = None,
28
+ bins: Optional[List[Tuple[float, float]]] = None,
29
+ bin_centers: Optional[List[float]] = None,
30
+ zero_inflated: Optional[bool] = True,
31
+ num_vpt: Optional[int] = None,
32
+ vpt_drop: Optional[float] = None,
33
+ input_size: Optional[int] = None,
34
+ adapter: Optional[bool] = False,
35
+ adapter_reduction: Optional[int] = None,
36
+ lora: Optional[bool] = False,
37
+ lora_rank: Optional[int] = None,
38
+ lora_alpha: Optional[float] = None,
39
+ lora_dropout: Optional[float] = None,
40
+ text_prompts: Optional[Dict[str, List[str]]] = None,
41
+ norm: Optional[str] = "none",
42
+ act: Optional[str] = "none",
43
+ ) -> None:
44
+ super().__init__()
45
+ if "mobileclip" in model_name.lower() or "vit" in model_name.lower():
46
+ model_name = model_name.replace("_", "-")
47
+ assert model_name in supported_models_and_weights, f"Model name should be one of {list(supported_models_and_weights.keys())}, but got {model_name}."
48
+ assert weight_name in supported_models_and_weights[model_name], f"Pretrained should be one of {supported_models_and_weights[model_name]}, but got {weight_name}."
49
+ assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}"
50
+ assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}"
51
+ assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
52
+ bins = [(float(b[0]), float(b[1])) for b in bins]
53
+ assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"
54
+
55
+ self.model_name = model_name
56
+ self.weight_name = weight_name
57
+ self.block_size = block_size
58
+ self.bins = bins
59
+ self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1))
60
+ self.zero_inflated = zero_inflated
61
+ self.text_prompts = text_prompts
62
+
63
+ # Image encoder
64
+ if model_name in vit_names_and_weights:
65
+ assert num_vpt is not None and num_vpt >= 0, f"Number of VPT tokens should be greater than 0, but got {num_vpt}."
66
+ vpt_drop = 0. if vpt_drop is None else vpt_drop
67
+ self.backbone = _vit(
68
+ model_name=model_name,
69
+ weight_name=weight_name,
70
+ num_vpt=num_vpt,
71
+ vpt_drop=vpt_drop,
72
+ block_size=block_size,
73
+ adapter=adapter,
74
+ adapter_reduction=adapter_reduction,
75
+ lora=lora,
76
+ lora_rank=lora_rank,
77
+ lora_alpha=lora_alpha,
78
+ lora_dropout=lora_dropout,
79
+ input_size=(input_size, input_size),
80
+ norm=norm,
81
+ act=act
82
+ )
83
+ elif model_name in convnext_names_and_weights:
84
+ self.backbone = _convnext(
85
+ model_name=model_name,
86
+ weight_name=weight_name,
87
+ block_size=block_size,
88
+ adapter=adapter,
89
+ adapter_reduction=adapter_reduction,
90
+ lora=lora,
91
+ lora_rank=lora_rank,
92
+ lora_alpha=lora_alpha,
93
+ lora_dropout=lora_dropout,
94
+ norm=norm,
95
+ act=act
96
+ )
97
+ elif model_name in resnet_names_and_weights:
98
+ self.backbone = _resnet(
99
+ model_name=model_name,
100
+ weight_name=weight_name,
101
+ block_size=block_size,
102
+ adapter=adapter,
103
+ adapter_reduction=adapter_reduction,
104
+ lora=lora,
105
+ lora_rank=lora_rank,
106
+ lora_alpha=lora_alpha,
107
+ lora_dropout=lora_dropout,
108
+ norm=norm,
109
+ act=act
110
+ )
111
+ elif model_name in mobileclip_names_and_weights:
112
+ self.backbone = _mobileclip(
113
+ model_name=model_name,
114
+ weight_name=weight_name,
115
+ block_size=block_size,
116
+ adapter=adapter,
117
+ adapter_reduction=adapter_reduction,
118
+ lora=lora,
119
+ lora_rank=lora_rank,
120
+ lora_alpha=lora_alpha,
121
+ lora_dropout=lora_dropout,
122
+ norm=norm,
123
+ act=act
124
+ )
125
+
126
+ self._build_text_feats()
127
+ self._build_head()
128
+
129
+ def _build_text_feats(self) -> None:
130
+ model_name, weight_name = self.model_name, self.weight_name
131
+ text_prompts = self.text_prompts
132
+
133
+ if text_prompts is None:
134
+ bins = [b[0] if b[0] == b[1] else b for b in self.bins] # if the bin is a single value (e.g., [0, 0]), use that value
135
+ if self.zero_inflated: # separate 0 from the rest
136
+ assert bins[0] == 0, f"Expected the first bin to be 0, got {bins[0]}."
137
+ bins_pi = [0, (1, float("inf"))]
138
+ bins_lambda = bins[1:]
139
+ pi_text_prompts = optimize_text_prompts(model_name, weight_name, bins_pi)
140
+ lambda_text_prompts = optimize_text_prompts(model_name, weight_name, bins_lambda)
141
+ self.text_prompts = {"pi": pi_text_prompts, "lambda": lambda_text_prompts}
142
+ pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
143
+ lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
144
+ pi_text_feats.requires_grad = False
145
+ lambda_text_feats.requires_grad = False
146
+ self.register_buffer("pi_text_feats", pi_text_feats)
147
+ self.register_buffer("lambda_text_feats", lambda_text_feats)
148
+
149
+ else:
150
+ text_prompts = optimize_text_prompts(model_name, weight_name, bins)
151
+ self.text_prompts = text_prompts
152
+ text_feats = encode_text(model_name, weight_name, text_prompts)
153
+ text_feats.requires_grad = False
154
+ self.register_buffer("text_feats", text_feats)
155
+
156
+ else:
157
+ if self.zero_inflated:
158
+ assert "pi" in text_prompts and "lambda" in text_prompts, f"Expected text_prompts to have keys 'pi' and 'lambda', got {text_prompts.keys()}."
159
+ pi_text_prompts = text_prompts["pi"]
160
+ lambda_text_prompts = text_prompts["lambda"]
161
+ pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
162
+ lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
163
+ pi_text_feats.requires_grad = False
164
+ lambda_text_feats.requires_grad = False
165
+ self.register_buffer("pi_text_feats", pi_text_feats)
166
+ self.register_buffer("lambda_text_feats", lambda_text_feats)
167
+
168
+ else:
169
+ text_feats = encode_text(model_name, weight_name, text_prompts)
170
+ text_feats.requires_grad = False
171
+ self.register_buffer("text_feats", text_feats)
172
+
173
+ def _build_head(self) -> None:
174
+ in_channels = self.backbone.in_features
175
+ out_channels = self.backbone.out_features
176
+ if self.zero_inflated:
177
+ self.pi_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
178
+ self.lambda_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
179
+
180
+ self.pi_head = conv1x1(in_channels, out_channels, bias=False)
181
+ self.lambda_head = conv1x1(in_channels, out_channels, bias=False)
182
+
183
+ else:
184
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
185
+ self.head = conv1x1(in_channels, out_channels, bias=False)
186
+
187
+ def forward(self, image: Tensor):
188
+ image_feats = self.backbone(image)
189
+ # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
190
+
191
+ if self.zero_inflated:
192
+ pi_image_feats, lambda_image_feats = self.pi_head(image_feats), self.lambda_head(image_feats)
193
+ pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
194
+ lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
195
+
196
+ pi_text_feats, lambda_text_feats = self.pi_text_feats, self.lambda_text_feats
197
+ pi_logit_scale, lambda_logit_scale = self.pi_logit_scale.exp(), self.lambda_logit_scale.exp()
198
+
199
+ pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t() # (B, H, W, 2), logits per image
200
+ lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t() # (B, H, W, N - 1), logits per image
201
+
202
+ pi_logit_map = pi_logit_map.permute(0, 3, 1, 2) # (B, 2, H, W)
203
+ lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2) # (B, N - 1, H, W)
204
+
205
+ lambda_map = (lambda_logit_map.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # (B, 1, H, W)
206
+
207
+ # pi_logit_map.softmax(dim=1)[:, 0] is the probability of zeros
208
+ den_map = pi_logit_map.softmax(dim=1)[:, 1:] * lambda_map # (B, 1, H, W)
209
+
210
+ if self.training:
211
+ return pi_logit_map, lambda_logit_map, lambda_map, den_map
212
+ else:
213
+ return den_map
214
+
215
+ else:
216
+ image_feats = self.head(image_feats)
217
+ image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)
218
+
219
+ text_feats = self.text_feats
220
+ logit_scale = self.logit_scale.exp()
221
+
222
+ logit_map = logit_scale * image_feats @ text_feats.t() # (B, H, W, N), logits per image
223
+ logit_map = logit_map.permute(0, 3, 1, 2) # (B, N, H, W)
224
+
225
+ den_map = (logit_map.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True) # (B, 1, H, W)
226
+
227
+ if self.training:
228
+ return logit_map, den_map
229
+ else:
230
+ return den_map
231
+
232
+
233
+ def _clip_ebc(
234
+ model_name: str,
235
+ weight_name: str,
236
+ block_size: Optional[int] = None,
237
+ bins: Optional[List[Tuple[float, float]]] = None,
238
+ bin_centers: Optional[List[float]] = None,
239
+ zero_inflated: Optional[bool] = True,
240
+ num_vpt: Optional[int] = None,
241
+ vpt_drop: Optional[float] = None,
242
+ input_size: Optional[int] = None,
243
+ adapter: Optional[bool] = False,
244
+ adapter_reduction: Optional[int] = None,
245
+ lora: Optional[bool] = False,
246
+ lora_rank: Optional[int] = None,
247
+ lora_alpha: Optional[float] = None,
248
+ lora_dropout: Optional[float] = None,
249
+ text_prompts: Optional[List[str]] = None,
250
+ norm: Optional[str] = "none",
251
+ act: Optional[str] = "none",
252
+ ) -> CLIP_EBC:
253
+ return CLIP_EBC(
254
+ model_name=model_name,
255
+ weight_name=weight_name,
256
+ block_size=block_size,
257
+ bins=bins,
258
+ bin_centers=bin_centers,
259
+ zero_inflated=zero_inflated,
260
+ num_vpt=num_vpt,
261
+ vpt_drop=vpt_drop,
262
+ input_size=input_size,
263
+ adapter=adapter,
264
+ adapter_reduction=adapter_reduction,
265
+ lora=lora,
266
+ lora_rank=lora_rank,
267
+ lora_alpha=lora_alpha,
268
+ lora_dropout=lora_dropout,
269
+ text_prompts=text_prompts,
270
+ norm=norm,
271
+ act=act,
272
+ )
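
A minimal usage sketch for the factory above (not part of the commit): the bins, bin centers, input size and import path are illustrative assumptions, and the first call downloads the CLIP weights and, if no text_prompts are passed, may spend time searching for prompt templates.

import torch
from models.clip_ebc.model import _clip_ebc  # assumes the repo root is on PYTHONPATH

bins = [(0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (4.0, float("inf"))]  # illustrative: a zero bin plus three count bins
bin_centers = [0.0, 1.0, 2.5, 5.0]                                # one representative count per bin

model = _clip_ebc(
    model_name="RN50", weight_name="openai",
    block_size=16, bins=bins, bin_centers=bin_centers, zero_inflated=True,
).eval()

with torch.no_grad():
    den_map = model(torch.rand(1, 3, 224, 224))  # eval mode returns only the density map
print(den_map.shape, den_map.sum().item())       # the predicted count is usually read off as the sum over blocks
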
models/clip_ebc/resnet.py ADDED
@@ -0,0 +1,236 @@
1
+ from torch import nn, Tensor
2
+ import open_clip
3
+ from peft import get_peft_model, LoraConfig
4
+
5
+ from ..utils import ConvRefine, ConvUpsample, ConvAdapter
6
+ from ..utils import _get_norm_layer, _get_activation
7
+
8
+
9
+ resnet_names_and_weights = {
10
+ "RN50": ["openai", "yfcc15m", "cc12m"],
11
+ "RN101": ["openai", "yfcc15m", "cc12m"],
12
+ "RN50x4": ["openai", "yfcc15m", "cc12m"],
13
+ "RN50x16": ["openai", "yfcc15m", "cc12m"],
14
+ "RN50x64": ["openai", "yfcc15m", "cc12m"],
15
+ }
16
+
17
+ refiner_channels = {
18
+ "RN50": 2048,
19
+ "RN101": 2048,
20
+ "RN50x4": 2560,
21
+ "RN50x16": 3072,
22
+ "RN50x64": 4096,
23
+ }
24
+
25
+ refiner_groups = {
26
+ "RN50": refiner_channels["RN50"] // 512, # 4
27
+ "RN101": refiner_channels["RN101"] // 512, # 4
28
+ "RN50x4": refiner_channels["RN50x4"] // 512, # 5
29
+ "RN50x16": refiner_channels["RN50x16"] // 512, # 6
30
+ "RN50x64": refiner_channels["RN50x64"] // 512, # 8
31
+ }
32
+
33
+
34
+ class ResNet(nn.Module):
35
+ def __init__(
36
+ self,
37
+ model_name: str,
38
+ weight_name: str,
39
+ block_size: int = 16,
40
+ adapter: bool = False,
41
+ adapter_reduction: int = 4,
42
+ norm: str = "none",
43
+ act: str = "none"
44
+ ) -> None:
45
+ super(ResNet, self).__init__()
46
+ assert model_name in resnet_names_and_weights, f"Model name should be one of {list(resnet_names_and_weights.keys())}, but got {model_name}."
47
+ assert weight_name in resnet_names_and_weights[model_name], f"Pretrained should be one of {resnet_names_and_weights[model_name]}, but got {weight_name}."
48
+ assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
49
+ self.model_name, self.weight_name = model_name, weight_name
50
+ self.block_size = block_size
51
+
52
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
53
+
54
+ self.adapter = adapter
55
+ if adapter:
56
+ for param in model.parameters():
57
+ param.requires_grad = False
58
+
59
+ # Stem
60
+ self.conv1 = model.conv1
61
+ self.bn1 = model.bn1
62
+ self.act1 = model.act1
63
+ self.conv2 = model.conv2
64
+ self.bn2 = model.bn2
65
+ self.act2 = model.act2
66
+ self.conv3 = model.conv3
67
+ self.bn3 = model.bn3
68
+ self.act3 = model.act3
69
+ self.avgpool = model.avgpool
70
+ # Stem: reduction = 4
71
+
72
+ # Layers
73
+ for idx in range(1, 5):
74
+ setattr(self, f"layer{idx}", getattr(model, f"layer{idx}"))
75
+ if adapter:
76
+ setattr(self, f"adapter{idx}", ConvAdapter(
77
+ in_channels=getattr(model, f"layer{idx}")[-1].conv3.out_channels,
78
+ bottleneck_channels=getattr(model, f"layer{idx}")[-1].conv3.out_channels // adapter_reduction,
79
+ ) if idx < 4 else nn.Identity()) # No adapter for the last layer
80
+
81
+ self.in_features = model.attnpool.c_proj.weight.shape[1]
82
+ self.out_features = model.attnpool.c_proj.weight.shape[0]
83
+
84
+ if norm == "bn":
85
+ norm_layer = nn.BatchNorm2d
86
+ elif norm == "ln":
87
+ norm_layer = nn.LayerNorm
88
+ else:
89
+ norm_layer = _get_norm_layer(model)
90
+
91
+ if act == "relu":
92
+ activation = nn.ReLU(inplace=True)
93
+ elif act == "gelu":
94
+ activation = nn.GELU()
95
+ else:
96
+ activation = _get_activation(model)
97
+
98
+ if block_size == 32:
99
+ self.refiner = ConvRefine(
100
+ in_channels=self.in_features,
101
+ out_channels=self.in_features,
102
+ norm_layer=norm_layer,
103
+ activation=activation,
104
+ groups=refiner_groups[self.model_name],
105
+ )
106
+ elif block_size == 16:
107
+ self.refiner = ConvUpsample(
108
+ in_channels=self.in_features,
109
+ out_channels=self.in_features,
110
+ norm_layer=norm_layer,
111
+ activation=activation,
112
+ groups=refiner_groups[self.model_name],
113
+ )
114
+ else: # block_size == 8
115
+ self.refiner = nn.Sequential(
116
+ ConvUpsample(
117
+ in_channels=self.in_features,
118
+ out_channels=self.in_features,
119
+ norm_layer=norm_layer,
120
+ activation=activation,
121
+ groups=refiner_groups[self.model_name],
122
+ ),
123
+ ConvUpsample(
124
+ in_channels=self.in_features,
125
+ out_channels=self.in_features,
126
+ norm_layer=norm_layer,
127
+ activation=activation,
128
+ groups=refiner_groups[self.model_name],
129
+ ),
130
+ )
131
+
132
+ def train(self, mode: bool = True):
133
+ if self.adapter and mode:
134
+ # training:
135
+ self.conv1.eval()
136
+ self.bn1.eval()
137
+ self.act1.eval()
138
+ self.conv2.eval()
139
+ self.bn2.eval()
140
+ self.act2.eval()
141
+ self.conv3.eval()
142
+ self.bn3.eval()
143
+ self.act3.eval()
144
+ self.avgpool.eval()
145
+
146
+ for idx in range(1, 5):
147
+ getattr(self, f"layer{idx}").eval()
148
+ getattr(self, f"adapter{idx}").train()
149
+
150
+ self.refiner.train()
151
+
152
+ else:
153
+ # evaluation:
154
+ for module in self.children():
155
+ module.train(mode)
156
+
157
+ def stem(self, x: Tensor) -> Tensor:
158
+ x = self.act1(self.bn1(self.conv1(x)))
159
+ x = self.act2(self.bn2(self.conv2(x)))
160
+ x = self.act3(self.bn3(self.conv3(x)))
161
+ x = self.avgpool(x)
162
+ return x
163
+
164
+ def forward(self, x: Tensor) -> Tensor:
165
+ x = self.stem(x)
166
+
167
+ x = self.layer1(x)
168
+ if self.adapter:
169
+ x = self.adapter1(x)
170
+
171
+ x = self.layer2(x)
172
+ if self.adapter:
173
+ x = self.adapter2(x)
174
+
175
+ x = self.layer3(x)
176
+ if self.adapter:
177
+ x = self.adapter3(x)
178
+
179
+ x = self.layer4(x)
180
+ if self.adapter:
181
+ x = self.adapter4(x)
182
+
183
+ x = self.refiner(x)
184
+ return x
185
+
186
+
187
+ def _resnet(
188
+ model_name: str,
189
+ weight_name: str,
190
+ block_size: int = 16,
191
+ adapter: bool = False,
192
+ adapter_reduction: int = 4,
193
+ lora: bool = False,
194
+ lora_rank: int = 16,
195
+ lora_alpha: float = 32.0,
196
+ lora_dropout: float = 0.1,
197
+ norm: str = "none",
198
+ act: str = "none"
199
+ ) -> ResNet:
200
+ assert not (lora and adapter), "LoRA and adapter cannot be used together."
201
+ model = ResNet(
202
+ model_name=model_name,
203
+ weight_name=weight_name,
204
+ block_size=block_size,
205
+ adapter=adapter,
206
+ adapter_reduction=adapter_reduction,
207
+ norm=norm,
208
+ act=act
209
+ )
210
+
211
+ if lora:
212
+ target_modules = []
213
+ for name, module in model.named_modules():
214
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
215
+ target_modules.append(name)
216
+
217
+ lora_config = LoraConfig(
218
+ r=lora_rank,
219
+ lora_alpha=lora_alpha,
220
+ lora_dropout=lora_dropout,
221
+ bias="none",
222
+ target_modules=target_modules,
223
+ )
224
+ model = get_peft_model(model, lora_config)
225
+
226
+ # Unfreeze BN layers
227
+ for name, module in model.named_modules():
228
+ if isinstance(module, nn.BatchNorm2d) and "refiner" not in name:
229
+ module.requires_grad_(True)
230
+
231
+ # Unfreeze refiner
232
+ for name, module in model.named_modules():
233
+ if "refiner" in name:
234
+ module.requires_grad_(True)
235
+
236
+ return model
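
A rough sketch of how this wrapper is meant to be used (not part of the commit; the import path and the RN50 shape in the comment are assumptions). With adapter=True the CLIP trunk stays frozen, so only the adapter and refiner parameters should require gradients:

import torch
from models.clip_ebc.resnet import _resnet  # assumes the repo root is on PYTHONPATH

backbone = _resnet("RN50", "openai", block_size=16, adapter=True).eval()
trainable = [name for name, p in backbone.named_parameters() if p.requires_grad]
print(trainable[:4])  # expected to list only adapter* / refiner parameters

with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
print(feats.shape)    # roughly (1, 2048, 14, 14) for RN50 with block_size=16
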
models/clip_ebc/utils.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+ import torch.nn.functional as F
4
+ import open_clip
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ from typing import Union, Tuple, List
8
+
9
+
10
+ num_to_word = {
11
+ "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
12
+ "10": "ten", "11": "eleven", "12": "twelve", "13": "thirteen", "14": "fourteen", "15": "fifteen", "16": "sixteen", "17": "seventeen", "18": "eighteen", "19": "nineteen",
13
+ "20": "twenty", "21": "twenty-one", "22": "twenty-two", "23": "twenty-three", "24": "twenty-four", "25": "twenty-five", "26": "twenty-six", "27": "twenty-seven", "28": "twenty-eight", "29": "twenty-nine",
14
+ "30": "thirty", "31": "thirty-one", "32": "thirty-two", "33": "thirty-three", "34": "thirty-four", "35": "thirty-five", "36": "thirty-six", "37": "thirty-seven", "38": "thirty-eight", "39": "thirty-nine",
15
+ "40": "forty", "41": "forty-one", "42": "forty-two", "43": "forty-three", "44": "forty-four", "45": "forty-five", "46": "forty-six", "47": "forty-seven", "48": "forty-eight", "49": "forty-nine",
16
+ "50": "fifty", "51": "fifty-one", "52": "fifty-two", "53": "fifty-three", "54": "fifty-four", "55": "fifty-five", "56": "fifty-six", "57": "fifty-seven", "58": "fifty-eight", "59": "fifty-nine",
17
+ "60": "sixty", "61": "sixty-one", "62": "sixty-two", "63": "sixty-three", "64": "sixty-four", "65": "sixty-five", "66": "sixty-six", "67": "sixty-seven", "68": "sixty-eight", "69": "sixty-nine",
18
+ "70": "seventy", "71": "seventy-one", "72": "seventy-two", "73": "seventy-three", "74": "seventy-four", "75": "seventy-five", "76": "seventy-six", "77": "seventy-seven", "78": "seventy-eight", "79": "seventy-nine",
19
+ "80": "eighty", "81": "eighty-one", "82": "eighty-two", "83": "eighty-three", "84": "eighty-four", "85": "eighty-five", "86": "eighty-six", "87": "eighty-seven", "88": "eighty-eight", "89": "eighty-nine",
20
+ "90": "ninety", "91": "ninety-one", "92": "ninety-two", "93": "ninety-three", "94": "ninety-four", "95": "ninety-five", "96": "ninety-six", "97": "ninety-seven", "98": "ninety-eight", "99": "ninety-nine",
21
+ "100": "one hundred"
22
+ }
23
+
24
+ prefixes = [
25
+ "",
26
+ "A photo of", "A block of", "An image of", "A picture of",
27
+ "There are",
28
+ "The image contains", "The photo contains", "The picture contains",
29
+ "The image shows", "The photo shows", "The picture shows",
30
+ ]
31
+ arabic_numeral = [True, False]
32
+ compares = [
33
+ "more than", "greater than", "higher than", "larger than", "bigger than", "greater than or equal to",
34
+ "at least", "no less than", "not less than", "not fewer than", "not lower than", "not smaller than", "not less than or equal to",
35
+ "over", "above", "beyond", "exceeding", "surpassing",
36
+ ]
37
+ suffixes = [
38
+ "people", "persons", "individuals", "humans", "faces", "heads", "figures", "",
39
+ ]
40
+
41
+
42
+ def num2word(num: Union[int, str]) -> str:
43
+ """
44
+ Convert the input number to the corresponding English word. For example, 1 -> "one", 2 -> "two", etc.
45
+ """
46
+ num = str(int(num))
47
+ return num_to_word.get(num, num)
48
+
49
+
50
+ def format_count(
51
+ bins: List[Union[float, Tuple[float, float]]],
52
+ ) -> List[List[str]]:
53
+ text_prompts = []
54
+ for prefix in prefixes:
55
+ for numeral in arabic_numeral:
56
+ for compare in compares:
57
+ for suffix in suffixes:
58
+ prompts = []
59
+ for bin in bins:
60
+ if isinstance(bin, (int, float)): # count is a single number
61
+ count = int(bin)
62
+ if count == 0 or count == 1:
63
+ count = num2word(count) if not numeral else count
64
+ prefix_ = "There is" if prefix == "There are" else prefix
65
+ suffix_ = "person" if suffix == "people" else suffix[:-1]
66
+ prompt = f"{prefix_} {count} {suffix_}"
67
+ else: # count > 1
68
+ count = num2word(count) if not numeral else count
69
+ prompt = f"{prefix} {count} {suffix}"
70
+
71
+ elif bin[1] == float("inf"): # count is (lower_bound, inf)
72
+ count = int(bin[0])
73
+ count = num2word(count) if not numeral else count
74
+ prompt = f"{prefix} {compare} {count} {suffix}"
75
+
76
+ else: # bin is (lower_bound, upper_bound)
77
+ left, right = int(bin[0]), int(bin[1])
78
+ left, right = num2word(left) if not numeral else left, num2word(right) if not numeral else right
79
+ prompt = f"{prefix} between {left} and {right} {suffix}"
80
+
81
+ # Remove starting and trailing whitespaces
82
+ prompt = prompt.strip() + "."
83
+
84
+ prompts.append(prompt)
85
+
86
+ text_prompts.append(prompts)
87
+
88
+ return text_prompts
89
+
90
+
91
+ def encode_text(
92
+ model_name: str,
93
+ weight_name: str,
94
+ text: List[str]
95
+ ) -> Tensor:
96
+ if torch.cuda.is_available():
97
+ device = torch.device("cuda")
98
+ elif torch.backends.mps.is_available():
99
+ device = torch.device("mps")
100
+ else:
101
+ device = torch.device("cpu")
102
+ text = open_clip.get_tokenizer(model_name)(text).to(device)
103
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).to(device)
104
+ model.eval()
105
+ with torch.no_grad():
106
+ text_feats = model.encode_text(text)
107
+ text_feats = F.normalize(text_feats, p=2, dim=-1).detach().cpu()
108
+ return text_feats
109
+
110
+
111
+ def optimize_text_prompts(
112
+ model_name: str,
113
+ weight_name: str,
114
+ flat_bins: List[Union[float, Tuple[float, float]]],
115
+ batch_size: int = 1024,
116
+ ) -> List[str]:
117
+ text_prompts = format_count(flat_bins)
118
+
119
+ # Find the template that has the smallest average similarity of bin prompts.
120
+ print("Finding the best setup for text prompts...")
121
+ text_prompts_ = [prompt for prompts in text_prompts for prompt in prompts] # flatten the list
122
+ text_feats = []
123
+ for i in tqdm(range(0, len(text_prompts_), batch_size)):
124
+ text_feats.append(encode_text(model_name, weight_name, text_prompts_[i: min(i + batch_size, len(text_prompts_))]))
125
+ text_feats = torch.cat(text_feats, dim=0)
126
+
127
+ sims = []
128
+ for idx, prompts in enumerate(text_prompts):
129
+ text_feats_ = text_feats[idx * len(prompts): (idx + 1) * len(prompts)]
130
+ sim = torch.mm(text_feats_, text_feats_.T)
131
+ sim = sim[~torch.eye(sim.shape[0], dtype=bool)].mean().item()
132
+ sims.append(sim)
133
+
134
+ optimal_prompts = text_prompts[np.argmin(sims)]
135
+ sim = sims[np.argmin(sims)]
136
+ print(f"Found the best text prompts: {optimal_prompts} (similarity: {sim:.2f})")
137
+ return optimal_prompts
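
optimize_text_prompts enumerates every (prefix, numeral, comparison, suffix) combination through format_count, encodes each candidate set with CLIP, and keeps the set whose normalised text features have the lowest mean off-diagonal cosine similarity. A sketch of the enumeration step alone, which needs no CLIP download (the bins are illustrative and the import path is an assumption):

from models.clip_ebc.utils import format_count  # assumes the repo root is on PYTHONPATH

flat_bins = [0, 1, (2, 3), (4, float("inf"))]   # one entry per bin: a scalar or a (low, high) tuple
candidates = format_count(flat_bins)

# 12 prefixes x 2 numeral modes x 18 comparisons x 8 suffixes with the lists defined above.
print(len(candidates))   # 3456 candidate template sets
print(candidates[0])     # one prompt per bin, e.g. '0 person.', ..., 'more than 4 people.'
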
models/clip_ebc/vit.py ADDED
@@ -0,0 +1,372 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import math
4
+ from einops import rearrange
5
+ import open_clip
6
+ from peft import get_peft_model, LoraConfig
7
+ from typing import Optional, Tuple
8
+
9
+ from ..utils import interpolate_pos_embed, ViTAdapter
10
+ # from ..utils import TransformerRefine, TransformerDownsample, TransformerUpsample
11
+ from ..utils import ConvRefine, ConvDownsample, ConvUpsample
12
+ from ..utils import _get_norm_layer, _get_activation
13
+
14
+
15
+ vit_names_and_weights = {
16
+ "ViT-B-32": [
17
+ "openai",
18
+ "laion400m_e31", "laion400m_e32", "laion2b_e16", "laion2b_s34b_b79k",
19
+ "datacomp_xl_s13b_b90k", "datacomp_m_s128m_b4k", "datacomp_s_s13m_b4k",
20
+ "commonpool_m_clip_s128m_b4k", "commonpool_m_laion_s128m_b4k", "commonpool_m_image_s128m_b4k", "commonpool_m_text_s128m_b4k", "commonpool_m_basic_s128m_b4k", "commonpool_m_s128m_b4k",
21
+ "commonpool_s_clip_s13m_b4k", "commonpool_s_laion_s13m_b4k", "commonpool_s_image_s13m_b4k", "commonpool_s_text_s13m_b4k", "commonpool_s_basic_s13m_b4k", "commonpool_s_s13m_b4k",
22
+ ],
23
+ "ViT-B-32-256": ["datacomp_s34b_b86k"],
24
+ "ViT-B-16": [
25
+ "openai",
26
+ "laion400m_e31", "laion400m_e32", "laion2b_s34b_b88k",
27
+ "datacomp_xl_s13b_b90k", "datacomp_l_s1b_b8k",
28
+ "commonpool_l_clip_s1b_b8k", "commonpool_l_laion_s1b_b8k", "commonpool_l_image_s1b_b8k", "commonpool_l_text_s1b_b8k", "commonpool_l_basic_s1b_b8k", "commonpool_l_s1b_b8k",
29
+ "dfn2b"
30
+ ],
31
+ "ViT-L-14": [
32
+ "openai",
33
+ "laion400m_e31", "laion400m_e32", "laion2b_s32b_b82k",
34
+ "datacomp_xl_s13b_b90k",
35
+ "commonpool_xl_clip_s13b_b90k", "commonpool_xl_laion_s13b_b90k", "commonpool_xl_s13b_b90k"
36
+ ],
37
+ "ViT-L-14-336": ["openai"],
38
+ "ViT-H-14": ["laion2b_s32b_b79k"],
39
+ "ViT-g-14": ["laion2b_s12b_b42k", "laion2b_s34b_b88k"],
40
+ "ViT-bigG-14": ["laion2b_s39b_b160k"],
41
+ }
42
+
43
+
44
+ refiner_channels = {
45
+ "ViT-B-32": 768,
46
+ "ViT-B-32-256": 768,
47
+ "ViT-B-16": 768,
48
+ "ViT-L-14": 1024,
49
+ "ViT-L-14-336": 1024,
50
+ "ViT-H-14": 1280,
51
+ "ViT-g-14": 1408,
52
+ "ViT-bigG-14": 1664,
53
+ }
54
+
55
+ refiner_groups = {
56
+ "ViT-B-32": 1,
57
+ "ViT-B-32-256": 1,
58
+ "ViT-B-16": 1,
59
+ "ViT-L-14": 1,
60
+ "ViT-L-14-336": 1,
61
+ "ViT-H-14": 1,
62
+ "ViT-g-14": refiner_channels["ViT-g-14"] // 704, # 2
63
+ "ViT-bigG-14": refiner_channels["ViT-bigG-14"] // 416, # 4
64
+ }
65
+
66
+
67
+
68
+ class ViT(nn.Module):
69
+ def __init__(
70
+ self,
71
+ model_name: str,
72
+ weight_name: str,
73
+ block_size: int = 16,
74
+ num_vpt: int = 32,
75
+ vpt_drop: float = 0.0,
76
+ adapter: bool = False,
77
+ adapter_reduction: int = 4,
78
+ input_size: Optional[Tuple[int, int]] = None,
79
+ norm: str = "none",
80
+ act: str = "none"
81
+ ) -> None:
82
+ super(ViT, self).__init__()
83
+ assert model_name in vit_names_and_weights, f"Model name should be one of {list(vit_names_and_weights.keys())}, but got {model_name}."
84
+ assert weight_name in vit_names_and_weights[model_name], f"Pretrained should be one of {vit_names_and_weights[model_name]}, but got {weight_name}."
85
+ if adapter:
86
+ assert num_vpt is None or num_vpt == 0, "num_vpt should be None or 0 when using adapter."
87
+ assert vpt_drop is None or vpt_drop == 0.0, "vpt_drop should be None or 0.0 when using adapter."
88
+ else:
89
+ assert num_vpt > 0, f"Number of VPT tokens should be greater than 0, but got {num_vpt}."
90
+ assert 0.0 <= vpt_drop < 1.0, f"VPT dropout should be in [0.0, 1.0), but got {vpt_drop}."
91
+
92
+ self.model_name, self.weight_name = model_name, weight_name
93
+ self.block_size = block_size
94
+ self.num_vpt = num_vpt
95
+ self.vpt_drop = vpt_drop
96
+ self.adapter = adapter
97
+
98
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
99
+
100
+ # Always freeze the parameters of the model
101
+ for param in model.parameters():
102
+ param.requires_grad = False
103
+
104
+ # Setup the model
105
+ self.input_size = input_size if input_size is not None else model.image_size
106
+ self.pretrain_size = model.image_size
107
+ self.patch_size = model.patch_size
108
+ self.class_embedding = model.class_embedding
109
+ self.positional_embedding = model.positional_embedding
110
+ self.embed_dim = model.class_embedding.shape[-1]
111
+
112
+ self.conv1 = model.conv1
113
+ self.ln_pre = model.ln_pre
114
+ self.resblocks = model.transformer.resblocks
115
+ self.num_layers = len(self.resblocks)
116
+ self.ln_post = model.ln_post
117
+
118
+ # Setup VPT tokens
119
+ val = math.sqrt(6. / float(3 * self.patch_size[0] + self.embed_dim))
120
+ for idx in range(self.num_layers):
121
+ if self.adapter:
122
+ setattr(self, f"adapter{idx}", ViTAdapter(
123
+ in_channels=self.embed_dim,
124
+ bottleneck_channels=self.embed_dim // adapter_reduction,
125
+ ))
126
+ else:
127
+ setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.embed_dim)))
128
+ nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
129
+ setattr(self, f"vpt_drop_{idx}", nn.Dropout(self.vpt_drop))
130
+
131
+ # Adjust the positional embedding to match the new input size
132
+ self._adjust_pos_embed()
133
+
134
+ in_features, out_features = model.proj.shape
135
+ self.in_features = in_features
136
+ self.out_features = out_features
137
+
138
+ patch_size = self.patch_size[0]
139
+ if patch_size in [16, 32]:
140
+ assert block_size in [8, 16, 32], f"Patch size is {patch_size}, so block size should be one of [8, 16, 32], but got {block_size}."
141
+ else: # patch_size == 14
142
+ assert block_size in [7, 14, 28], f"Patch size is 14, but got block size {block_size}."
143
+
144
+ if norm == "bn":
145
+ norm_layer = nn.BatchNorm2d
146
+ elif norm == "ln":
147
+ norm_layer = nn.LayerNorm
148
+ else:
149
+ norm_layer = _get_norm_layer(model)
150
+
151
+ if act == "relu":
152
+ activation = nn.ReLU(inplace=True)
153
+ elif act == "gelu":
154
+ activation = nn.GELU()
155
+ else:
156
+ activation = _get_activation(model)
157
+
158
+ if block_size == patch_size:
159
+ self.refiner = ConvRefine(
160
+ in_channels=self.in_features,
161
+ out_channels=self.in_features,
162
+ norm_layer=norm_layer,
163
+ activation=activation,
164
+ groups=refiner_groups[self.model_name],
165
+ )
166
+
167
+ elif block_size < patch_size: # upsample
168
+ if block_size == 8 and patch_size == 32:
169
+ self.refiner = nn.Sequential(
170
+ ConvUpsample(
171
+ in_channels=self.in_features,
172
+ out_channels=self.in_features,
173
+ norm_layer=norm_layer,
174
+ activation=activation,
175
+ groups=refiner_groups[self.model_name],
176
+ ),
177
+ ConvUpsample(
178
+ in_channels=self.in_features,
179
+ out_channels=self.in_features,
180
+ norm_layer=norm_layer,
181
+ activation=activation,
182
+ groups=refiner_groups[self.model_name],
183
+ ),
184
+ )
185
+ else:
186
+ self.refiner = ConvUpsample(
187
+ in_channels=self.in_features,
188
+ out_channels=self.in_features,
189
+ norm_layer=norm_layer,
190
+ activation=activation,
191
+ groups=refiner_groups[self.model_name],
192
+ )
193
+
194
+ else: # downsample
195
+ assert block_size // patch_size == 2, f"Block size {block_size} should be 2 times the patch size {patch_size}."
196
+ self.refiner = ConvDownsample(
197
+ in_channels=self.in_features,
198
+ out_channels=self.in_features,
199
+ norm_layer=norm_layer,
200
+ activation=activation,
201
+ groups=refiner_groups[self.model_name],
202
+ )
203
+
204
+ def _adjust_pos_embed(self) -> None:
205
+ """
206
+ Interpolate the pretrained positional embedding to match the spatial resolution
207
+ implied by `self.input_size` and register it as a frozen parameter.
208
+
209
+ The original resolution is taken from `self.pretrain_size`; this method takes no
210
+ arguments and modifies `self.positional_embedding` in place.
211
+ """
212
+ self.positional_embedding = nn.Parameter(self._interpolate_pos_embed(self.pretrain_size[0], self.pretrain_size[1], self.input_size[0], self.input_size[1]), requires_grad=False)
213
+
214
+ def _interpolate_pos_embed(self, orig_h: int, orig_w: int, new_h: int, new_w: int) -> Tensor:
215
+ """
216
+ Interpolate the positional embedding to match the spatial resolution of the feature map.
217
+
218
+ Args:
219
+ orig_h, orig_w: The original spatial resolution of the image.
220
+ new_h, new_w: The new spatial resolution of the image.
221
+ """
222
+ if (orig_h, orig_w) == (new_h, new_w):
223
+ return self.positional_embedding
224
+
225
+ orig_h_patches, orig_w_patches = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
226
+ new_h_patches, new_w_patches = new_h // self.patch_size[0], new_w // self.patch_size[1]
227
+ class_pos_embed, patch_pos_embed = self.positional_embedding[:1, :], self.positional_embedding[1:, :]
228
+ patch_pos_embed = rearrange(patch_pos_embed, "(h w) d -> d h w", h=orig_h_patches, w=orig_w_patches)
229
+ patch_pos_embed = interpolate_pos_embed(patch_pos_embed, size=(new_h_patches, new_w_patches))
230
+ patch_pos_embed = rearrange(patch_pos_embed, "d h w -> (h w) d")
231
+ pos_embed = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
232
+ return pos_embed
233
+
234
+ def train(self, mode: bool = True):
235
+ if mode:
236
+ # training:
237
+ self.conv1.eval()
238
+ self.ln_pre.eval()
239
+ self.resblocks.eval()
240
+ self.ln_post.eval()
241
+
242
+ for idx in range(self.num_layers):
243
+ getattr(self, f"adapter{idx}" if self.adapter else f"vpt_drop_{idx}").train()  # vpt_drop_{idx} does not exist in adapter mode
244
+
245
+ self.refiner.train()
246
+
247
+ else:
248
+ # evaluation:
249
+ for module in self.children():
250
+ module.train(mode)
251
+
252
+ def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
253
+ vpt = getattr(self, f"vpt_{layer}").unsqueeze(0).expand(batch_size, -1, -1).to(device) # (batch_size, num_vpt, embed_dim)
254
+ vpt = getattr(self, f"vpt_drop_{layer}")(vpt)
255
+
256
+ return vpt
257
+
258
+ def _forward_patch_embed(self, x: Tensor) -> Tensor:
259
+ # This step performs 1) embed x into patches; 2) append the class token; 3) add positional embeddings.
260
+ assert len(x.shape) == 4, f"Expected input to have shape (batch_size, 3, height, width), but got {x.shape}"
261
+ batch_size, _, height, width = x.shape
262
+
263
+ # Step 1: Embed x into patches
264
+ x = self.conv1(x)
265
+
266
+ # Step 2: Append the class token
267
+ class_embedding = self.class_embedding.expand(batch_size, 1, -1)
268
+ x = rearrange(x, "b d h w -> b (h w) d")
269
+ x = torch.cat([class_embedding, x], dim=1)
270
+
271
+ # Step 3: Add positional embeddings
272
+ pos_embed = self._interpolate_pos_embed(orig_h=self.input_size[0], orig_w=self.input_size[1], new_h=height, new_w=width).expand(batch_size, -1, -1)
273
+ x = x + pos_embed
274
+
275
+ x = self.ln_pre(x)
276
+ return x
277
+
278
+ def _forward_vpt(self, x: Tensor, idx: int) -> Tensor:
279
+ batch_size = x.shape[0]
280
+ device = x.device
281
+
282
+ # Assemble
283
+ vpt = self._prepare_vpt(idx, batch_size, device)
284
+ x = torch.cat([
285
+ x[:, :1, :], # class token
286
+ vpt,
287
+ x[:, 1:, :] # patches
288
+ ], dim=1)
289
+
290
+ # Forward
291
+ x = self.resblocks[idx](x)
292
+
293
+ # Disassemble
294
+ x = torch.cat([
295
+ x[:, :1, :], # class token
296
+ x[:, 1 + self.num_vpt:, :] # patches
297
+ ], dim=1)
298
+
299
+ return x
300
+
301
+ def _forward_adapter(self, x: Tensor, idx: int) -> Tensor:
302
+ return getattr(self, f"adapter{idx}")(x)
303
+
304
+ def forward_encoder(self, x: Tensor) -> Tensor:
305
+ x = self._forward_patch_embed(x)
306
+ for idx in range(self.num_layers):
307
+ x = self._forward_adapter(x, idx) if self.adapter else self._forward_vpt(x, idx)
308
+ x = self.ln_post(x)
309
+ return x
310
+
311
+ def forward(self, x: Tensor) -> Tensor:
312
+ orig_h, orig_w = x.shape[-2:]
313
+ num_patches_h, num_patches_w = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
314
+ x = self.forward_encoder(x)
315
+ x = x[:, 1:, :] # remove the class token
316
+ x = rearrange(x, "b (h w) d -> b d h w", h=num_patches_h, w=num_patches_w)
317
+
318
+ x = self.refiner(x)
319
+ return x
320
+
321
+
322
+ def _vit(
323
+ model_name: str,
324
+ weight_name: str,
325
+ block_size: int = 16,
326
+ num_vpt: int = 32,
327
+ vpt_drop: float = 0.1,
328
+ adapter: bool = False,
329
+ adapter_reduction: int = 4,
330
+ lora: bool = False,
331
+ lora_rank: int = 16,
332
+ lora_alpha: float = 32.0,
333
+ lora_dropout: float = 0.1,
334
+ input_size: Optional[Tuple[int, int]] = None,
335
+ norm: str = "none",
336
+ act: str = "none"
337
+ ) -> ViT:
338
+ assert not (lora and adapter), "LoRA and adapter cannot be used together."
339
+ model = ViT(
340
+ model_name=model_name,
341
+ weight_name=weight_name,
342
+ block_size=block_size,
343
+ num_vpt=num_vpt,
344
+ vpt_drop=vpt_drop,
345
+ adapter=adapter,
346
+ adapter_reduction=adapter_reduction,
347
+ input_size=input_size,
348
+ norm=norm,
349
+ act=act
350
+ )
351
+
352
+ if lora:
353
+ target_modules = []
354
+ for name, module in model.named_modules():
355
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.MultiheadAttention)) and "refiner" not in name:
356
+ target_modules.append(name)
357
+
358
+ lora_config = LoraConfig(
359
+ r=lora_rank,
360
+ lora_alpha=lora_alpha,
361
+ lora_dropout=lora_dropout,
362
+ bias="none",
363
+ target_modules=target_modules,
364
+ )
365
+ model = get_peft_model(model, lora_config)
366
+
367
+ # Unfreeze refiner
368
+ for name, module in model.named_modules():
369
+ if "refiner" in name:
370
+ module.requires_grad_(True)
371
+
372
+ return model
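
A standalone sketch of the per-layer VPT bookkeeping implemented in _forward_vpt above (shapes are illustrative for a ViT-B/16 at 224 px): the learned prompt tokens are inserted between the class token and the patch tokens before each block and stripped again afterwards, so the sequence length seen outside the block never changes.

import torch

B, N, D, num_vpt = 2, 197, 768, 32                 # 1 class token + 196 patches
x = torch.randn(B, N, D)
vpt = torch.randn(B, num_vpt, D)

with_prompts = torch.cat([x[:, :1], vpt, x[:, 1:]], dim=1)   # (B, 1 + num_vpt + 196, D)
# ... the transformer block would run here ...
without_prompts = torch.cat([with_prompts[:, :1], with_prompts[:, 1 + num_vpt:]], dim=1)
assert without_prompts.shape == x.shape
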
models/ebc/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .model import EBC, _ebc
2
+
3
+ __all__ = ["EBC", "_ebc"]
models/ebc/cannet.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from typing import List, Optional
5
+
6
+ from .csrnet import _csrnet, _csrnet_bn
7
+ from ..utils import _init_weights
8
+
9
+ EPS = 1e-6
10
+
11
+
12
+ class ContextualModule(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_channels: int,
16
+ out_channels: int = 512,
17
+ scales: List[int] = [1, 2, 3, 6],
18
+ ) -> None:
19
+ super().__init__()
20
+ self.scales = scales
21
+ self.multiscale_modules = nn.ModuleList([self.__make_scale__(in_channels, size) for size in scales])
22
+ self.bottleneck = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1)
23
+ self.relu = nn.ReLU(inplace=True)
24
+ self.weight_net = nn.Conv2d(in_channels, in_channels, kernel_size=1)
25
+ self.apply(_init_weights)
26
+
27
+ def __make_weight__(self, feature: Tensor, scale_feature: Tensor) -> Tensor:
28
+ weight_feature = feature - scale_feature
29
+ weight_feature = self.weight_net(weight_feature)
30
+ return F.sigmoid(weight_feature)
31
+
32
+ def __make_scale__(self, channels: int, size: int) -> nn.Module:
33
+ return nn.Sequential(
34
+ nn.AdaptiveAvgPool2d(output_size=(size, size)),
35
+ nn.Conv2d(channels, channels, kernel_size=1, bias=False),
36
+ )
37
+
38
+ def forward(self, feature: Tensor) -> Tensor:
39
+ h, w = feature.shape[-2:]
40
+ multiscale_feats = [F.interpolate(input=scale(feature), size=(h, w), mode="bilinear") for scale in self.multiscale_modules]
41
+ weights = [self.__make_weight__(feature, scale_feature) for scale_feature in multiscale_feats]
42
+ multiscale_feats = sum([multiscale_feats[i] * weights[i] for i in range(len(weights))]) / (sum(weights) + EPS)
43
+ overall_features = torch.cat([multiscale_feats, feature], dim=1)
44
+ overall_features = self.bottleneck(overall_features)
45
+ overall_features = self.relu(overall_features)
46
+ return overall_features
47
+
48
+
49
+ class CANNet(nn.Module):
50
+ def __init__(
51
+ self,
52
+ model_name: str,
53
+ block_size: Optional[int] = None,
54
+ norm: str = "none",
55
+ act: str = "none",
56
+ scales: List[int] = [1, 2, 3, 6],
57
+ ) -> None:
58
+ super().__init__()
59
+ assert model_name in ["csrnet", "csrnet_bn"], f"Model name should be one of ['csrnet', 'csrnet_bn'], but got {model_name}."
60
+ assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
61
+ assert isinstance(scales, (tuple, list)), f"scales should be a list or tuple, got {type(scales)}."
62
+ assert len(scales) > 0, f"Expected at least one size, got {len(scales)}."
63
+ assert all([isinstance(size, int) for size in scales]), f"Expected all size to be int, got {scales}."
64
+ self.model_name = model_name
65
+ self.scales = scales
66
+
67
+ csrnet = _csrnet(block_size=block_size, norm=norm, act=act) if model_name == "csrnet" else _csrnet_bn(block_size=block_size, norm=norm, act=act)
68
+ self.block_size = csrnet.block_size
69
+
70
+ self.encoder = csrnet.encoder
71
+ self.encoder_channels = csrnet.encoder_channels
72
+ self.encoder_reduction = csrnet.encoder_reduction # feature map size compared to input size
73
+
74
+ self.refiner = nn.Sequential(
75
+ csrnet.refiner,
76
+ ContextualModule(csrnet.refiner_channels, 512, scales)
77
+ )
78
+ self.refiner_channels = 512
79
+ self.refiner_reduction = csrnet.refiner_reduction # feature map size compared to input size
80
+
81
+ self.decoder = csrnet.decoder
82
+ self.decoder_channels = csrnet.decoder_channels
83
+ self.decoder_reduction = csrnet.decoder_reduction
84
+
85
+ def encode(self, x: Tensor) -> Tensor:
86
+ return self.encoder(x)
87
+
88
+ def refine(self, x: Tensor) -> Tensor:
89
+ return self.refiner(x)
90
+
91
+ def decode(self, x: Tensor) -> Tensor:
92
+ return self.decoder(x)
93
+
94
+ def forward(self, x: Tensor) -> Tensor:
95
+ x = self.encode(x)
96
+ x = self.refine(x)
97
+ x = self.decode(x)
98
+ return x
99
+
100
+
101
+ def _cannet(block_size: Optional[int] = None, norm: str = "none", act: str = "none", scales: List[int] = [1, 2, 3, 6]) -> CANNet:
102
+ return CANNet("csrnet", block_size=block_size, norm=norm, act=act, scales=scales)
103
+
104
+ def _cannet_bn(block_size: Optional[int] = None, norm: str = "none", act: str = "none", scales: List[int] = [1, 2, 3, 6]) -> CANNet:
105
+ return CANNet("csrnet_bn", block_size=block_size, norm=norm, act=act, scales=scales)
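
A quick shape-check sketch for the CAN backbone (not part of the commit; the import path is an assumption and the first run downloads the VGG-16 weights). With block_size=16 the refiner downsamples the stride-8 VGG features once, the contextual module fuses the four pooled scales, and the dilated decoder keeps that resolution:

import torch
from models.ebc.cannet import _cannet  # assumes the repo root is on PYTHONPATH

backbone = _cannet(block_size=16).eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
print(feats.shape)  # expected roughly (1, 128, 14, 14): 128 decoder channels at stride 16
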
models/ebc/csrnet.py ADDED
@@ -0,0 +1,104 @@
1
+ from torch import nn, Tensor
2
+ from torch.hub import load_state_dict_from_url
3
+ from typing import Optional
4
+
5
+ from .vgg import VGG
6
+ from .utils import make_vgg_layers, vgg_urls
7
+ from ..utils import _init_weights, ConvDownsample, _get_activation, _get_norm_layer
8
+
9
+ EPS = 1e-6
10
+
11
+
12
+ encoder_cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512]
13
+ decoder_cfg = [512, 512, 512, 256, 128]
14
+
15
+
16
+ class CSRNet(nn.Module):
17
+ def __init__(
18
+ self,
19
+ model_name: str,
20
+ block_size: Optional[int] = None,
21
+ norm: str = "none",
22
+ act: str = "none"
23
+ ) -> None:
24
+ super().__init__()
25
+ assert model_name in ["vgg16", "vgg16_bn"], f"Model name should be one of ['vgg16', 'vgg16_bn'], but got {model_name}."
26
+ assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
27
+ self.model_name = model_name
28
+
29
+ vgg = VGG(make_vgg_layers(encoder_cfg, in_channels=3, batch_norm="bn" in model_name, dilation=1))
30
+ vgg.load_state_dict(load_state_dict_from_url(vgg_urls[model_name]), strict=False)
31
+ self.encoder = vgg.features
32
+ self.encoder_reduction = 8
33
+ self.encoder_channels = 512
34
+ self.block_size = block_size if block_size is not None else 8
35
+
36
+ if norm == "bn":
37
+ norm_layer = nn.BatchNorm2d
38
+ elif norm == "ln":
39
+ norm_layer = nn.LayerNorm
40
+ else:
41
+ norm_layer = _get_norm_layer(vgg)
42
+
43
+ if act == "relu":
44
+ activation = nn.ReLU(inplace=True)
45
+ elif act == "gelu":
46
+ activation = nn.GELU()
47
+ else:
48
+ activation = _get_activation(vgg)
49
+
50
+ if self.block_size == self.encoder_reduction:
51
+ self.refiner = nn.Identity()
52
+ elif self.block_size > self.encoder_reduction:
53
+ if self.block_size == 32:
54
+ self.refiner = nn.Sequential(
55
+ ConvDownsample(
56
+ in_channels=self.encoder_channels,
57
+ out_channels=self.encoder_channels,
58
+ norm_layer=norm_layer,
59
+ activation=activation,
60
+ ),
61
+ ConvDownsample(
62
+ in_channels=self.encoder_channels,
63
+ out_channels=self.encoder_channels,
64
+ norm_layer=norm_layer,
65
+ activation=activation,
66
+ )
67
+ )
68
+ elif self.block_size == 16:
69
+ self.refiner = ConvDownsample(
70
+ in_channels=self.encoder_channels,
71
+ out_channels=self.encoder_channels,
72
+ norm_layer=norm_layer,
73
+ activation=activation,
74
+ )
75
+ self.refiner_channels = self.encoder_channels
76
+ self.refiner_reduction = self.block_size
77
+
78
+ decoder = make_vgg_layers(decoder_cfg, in_channels=512, batch_norm="bn" in model_name, dilation=2)
79
+ decoder.apply(_init_weights)
80
+ self.decoder = decoder
81
+ self.decoder_channels = decoder_cfg[-1]
82
+ self.decoder_reduction = self.refiner_reduction
83
+
84
+ def encode(self, x: Tensor) -> Tensor:
85
+ return self.encoder(x)
86
+
87
+ def refine(self, x: Tensor) -> Tensor:
88
+ return self.refiner(x)
89
+
90
+ def decode(self, x: Tensor) -> Tensor:
91
+ return self.decoder(x)
92
+
93
+ def forward(self, x: Tensor) -> Tensor:
94
+ x = self.encode(x)
95
+ x = self.refine(x)
96
+ x = self.decode(x)
97
+ return x
98
+
99
+
100
+ def _csrnet(block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> CSRNet:
101
+ return CSRNet("vgg16", block_size=block_size, norm=norm, act=act)
102
+
103
+ def _csrnet_bn(block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> CSRNet:
104
+ return CSRNet("vgg16_bn", block_size=block_size, norm=norm, act=act)
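
The block_size argument only controls how much extra downsampling the refiner adds on top of the stride-8 VGG encoder: none for 8, one ConvDownsample for 16, two for 32. A small sketch of the resulting output strides (import path assumed; weights are downloaded on first use):

import torch
from models.ebc.csrnet import _csrnet_bn  # assumes the repo root is on PYTHONPATH

x = torch.rand(1, 3, 256, 256)
for bs in (8, 16, 32):
    m = _csrnet_bn(block_size=bs).eval()
    with torch.no_grad():
        y = m(x)
    print(bs, tuple(y.shape))  # spatial size should come out as 256 // bs on each side
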
models/ebc/hrnet.py ADDED
@@ -0,0 +1,195 @@
1
+ import timm
2
+ import torch.nn.functional as F
3
+ from torch import nn, Tensor
4
+ from functools import partial
5
+ from typing import Optional
6
+
7
+ from ..utils import ConvRefine, _get_norm_layer, _get_activation
8
+
9
+
10
+ available_hrnets = [
11
+ "hrnet_w18", "hrnet_w18_small", "hrnet_w18_small_v2",
12
+ "hrnet_w30", "hrnet_w32", "hrnet_w40", "hrnet_w44", "hrnet_w48", "hrnet_w64",
13
+ ]
14
+
15
+
16
+ class HRNet(nn.Module):
17
+ def __init__(
18
+ self,
19
+ model_name: str,
20
+ block_size: Optional[int] = None,
21
+ norm: str = "none",
22
+ act: str = "none"
23
+ ) -> None:
24
+ super().__init__()
25
+ assert model_name in available_hrnets, f"Model name should be one of {available_hrnets}"
26
+ assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
27
+ self.model_name = model_name
28
+ self.block_size = block_size if block_size is not None else 32
29
+
30
+ model = timm.create_model(model_name, pretrained=True)
31
+
32
+ self.conv1 = model.conv1
33
+ self.bn1 = model.bn1
34
+ self.act1 = model.act1
35
+ self.conv2 = model.conv2
36
+ self.bn2 = model.bn2
37
+ self.act2 = model.act2
38
+
39
+ self.layer1 = model.layer1
40
+
41
+ self.transition1 = model.transition1
42
+ self.stage2 = model.stage2
43
+
44
+ self.transition2 = model.transition2
45
+ self.stage3 = model.stage3
46
+
47
+ self.transition3 = model.transition3
48
+ self.stage4 = model.stage4
49
+
50
+ incre_modules = model.incre_modules
51
+ downsamp_modules = model.downsamp_modules
52
+
53
+ assert len(incre_modules) == 4, f"Expected 4 incre_modules, got {len(incre_modules)}"
54
+ assert len(downsamp_modules) == 3, f"Expected 3 downsamp_modules, got {len(downsamp_modules)}"
55
+
56
+ self.out_channels_4 = incre_modules[0][0].downsample[0].out_channels
57
+ self.out_channels_8 = incre_modules[1][0].downsample[0].out_channels
58
+ self.out_channels_16 = incre_modules[2][0].downsample[0].out_channels
59
+ self.out_channels_32 = incre_modules[3][0].downsample[0].out_channels
60
+
61
+ if self.block_size == 8:
62
+ self.encoder_reduction = 8
63
+ self.encoder_channels = self.out_channels_8
64
+ self.incre_modules = incre_modules[:2]
65
+ self.downsamp_modules = downsamp_modules[:1]
66
+
67
+ self.refiner = nn.Identity()
68
+ self.refiner_reduction = 8
69
+ self.refiner_channels = self.out_channels_8
70
+
71
+ elif self.block_size == 16:
72
+ self.encoder_reduction = 16
73
+ self.encoder_channels = self.out_channels_16
74
+ self.incre_modules = incre_modules[:3]
75
+ self.downsamp_modules = downsamp_modules[:2]
76
+
77
+ self.refiner = nn.Identity()
78
+ self.refiner_reduction = 16
79
+ self.refiner_channels = self.out_channels_16
80
+
81
+ else: # self.block_size == 32
82
+ self.encoder_reduction = 32
83
+ self.encoder_channels = self.out_channels_32
84
+ self.incre_modules = incre_modules
85
+ self.downsamp_modules = downsamp_modules
86
+
87
+ self.refiner = nn.Identity()
88
+ self.refiner_reduction = 32
89
+ self.refiner_channels = self.out_channels_32
90
+
91
+ # define the decoder
92
+ if self.refiner_channels <= 512:
93
+ groups = 1
94
+ elif self.refiner_channels <= 1024:
95
+ groups = 2
96
+ elif self.refiner_channels <= 2048:
97
+ groups = 4
98
+ else:
99
+ groups = 8
100
+
101
+ if norm == "bn":
102
+ norm_layer = nn.BatchNorm2d
103
+ elif norm == "ln":
104
+ norm_layer = nn.LayerNorm
105
+ else:
106
+ norm_layer = _get_norm_layer(model)
107
+
108
+ if act == "relu":
109
+ activation = nn.ReLU(inplace=True)
110
+ elif act == "gelu":
111
+ activation = nn.GELU()
112
+ else:
113
+ activation = _get_activation(model)
114
+
115
+ decoder_block = partial(ConvRefine, groups=groups, norm_layer=norm_layer, activation=activation)
116
+ if self.refiner_channels <= 256:
117
+ self.decoder = nn.Identity()
118
+ self.decoder_channels = self.refiner_channels
119
+ elif self.refiner_channels <= 512:
120
+ self.decoder = decoder_block(self.refiner_channels, self.refiner_channels // 2)
121
+ self.decoder_channels = self.refiner_channels // 2
122
+ elif self.refiner_channels <= 1024:
123
+ self.decoder = nn.Sequential(
124
+ decoder_block(self.refiner_channels, self.refiner_channels // 2),
125
+ decoder_block(self.refiner_channels // 2, self.refiner_channels // 4),
126
+ )
127
+ self.decoder_channels = self.refiner_channels // 4
128
+ else:
129
+ self.decoder = nn.Sequential(
130
+ decoder_block(self.refiner_channels, self.refiner_channels // 2),
131
+ decoder_block(self.refiner_channels // 2, self.refiner_channels // 4),
132
+ decoder_block(self.refiner_channels // 4, self.refiner_channels // 8),
133
+ )
134
+ self.decoder_channels = self.refiner_channels // 8
135
+
136
+ self.decoder_reduction = self.refiner_reduction
137
+
138
+ def _interpolate(self, x: Tensor) -> Tensor:
139
+ # This method adjusts the spatial dimensions of the input tensor so that they are divisible by 32.
140
+ if x.shape[-1] % 32 != 0 or x.shape[-2] % 32 != 0:
141
+ new_h = int(round(x.shape[-2] / 32) * 32)
142
+ new_w = int(round(x.shape[-1] / 32) * 32)
143
+ return F.interpolate(x, size=(new_h, new_w), mode="bicubic", align_corners=False)
144
+
145
+ return x
146
+
147
+
148
+ def encode(self, x: Tensor) -> Tensor:
149
+ x = self.conv1(x)
150
+ x = self.bn1(x)
151
+ x = self.act1(x)
152
+
153
+ x = self.conv2(x)
154
+ x = self.bn2(x)
155
+ x = self.act2(x)
156
+
157
+ x = self.layer1(x)
158
+
159
+ x = [t(x) for t in self.transition1]
160
+ x = self.stage2(x)
161
+
162
+ x = [t(x[-1]) if not isinstance(t, nn.Identity) else x[i] for i, t in enumerate(self.transition2)]
163
+ x = self.stage3(x)
164
+
165
+ x = [t(x[-1]) if not isinstance(t, nn.Identity) else x[i] for i, t in enumerate(self.transition3)]
166
+ x = self.stage4(x)
167
+
168
+ assert len(x) == 4, f"Expected 4 outputs, got {len(x)}"
169
+
170
+ feats = None
171
+ for i, incre in enumerate(self.incre_modules):
172
+ if feats is None:
173
+ feats = incre(x[i])
174
+ else:
175
+ down = self.downsamp_modules[i - 1] # needed for torchscript module indexing
176
+ feats = incre(x[i]) + down.forward(feats)
177
+
178
+ return feats
179
+
180
+ def refine(self, x: Tensor) -> Tensor:
181
+ return self.refiner(x)
182
+
183
+ def decode(self, x: Tensor) -> Tensor:
184
+ return self.decoder(x)
185
+
186
+ def forward(self, x: Tensor) -> Tensor:
187
+ x = self._interpolate(x)
188
+ x = self.encode(x)
189
+ x = self.refine(x)
190
+ x = self.decode(x)
191
+ return x
192
+
193
+
194
+ def _hrnet(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> HRNet:
195
+ return HRNet(model_name, block_size, norm, act)
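
_interpolate above resizes any input whose sides are not multiples of 32 to the nearest multiple before encoding, since the HRNet stages and their cross-resolution fusion assume a stride-32-compatible input. A standalone illustration of that rounding rule:

def round_to_multiple_of_32(h: int, w: int) -> tuple:
    # mirrors the arithmetic in HRNet._interpolate
    return int(round(h / 32) * 32), int(round(w / 32) * 32)

print(round_to_multiple_of_32(250, 333))  # (256, 320): bicubic resize target before encoding
print(round_to_multiple_of_32(224, 224))  # (224, 224): already divisible, left untouched
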
models/ebc/model.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from einops import rearrange
4
+
5
+ from typing import Tuple, Union, Dict, Optional, List
6
+ from functools import partial
7
+
8
+ from .cannet import _cannet, _cannet_bn
9
+ from .csrnet import _csrnet, _csrnet_bn
10
+ from .vgg import _vgg_encoder_decoder, _vgg_encoder
11
+ from .vit import _vit, supported_vit_backbones
12
+ from .timm_models import _timm_model
13
+ from .timm_models import regular_models as timm_regular_models, heavy_models as timm_heavy_models, light_models as timm_light_models, lighter_models as timm_lighter_models
14
+ from .hrnet import _hrnet, available_hrnets
15
+
16
+ from ..utils import conv1x1
17
+
18
+
19
+ regular_models = [
20
+ "csrnet", "csrnet_bn",
21
+ "cannet", "cannet_bn",
22
+ "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
23
+ "vgg11_ae", "vgg11_bn_ae", "vgg13_ae", "vgg13_bn_ae", "vgg16_ae", "vgg16_bn_ae", "vgg19_ae", "vgg19_bn_ae",
24
+ *timm_regular_models,
25
+ *available_hrnets,
26
+ ]
27
+
28
+ heavy_models = timm_heavy_models
29
+
30
+ light_models = timm_light_models
31
+
32
+ lighter_models = timm_lighter_models
33
+
34
+ transformer_models = supported_vit_backbones
35
+
36
+ supported_models = regular_models + heavy_models + light_models + lighter_models + transformer_models
37
+
38
+
39
+
40
+ class EBC(nn.Module):
41
+ def __init__(
42
+ self,
43
+ model_name: str,
44
+ block_size: int,
45
+ bins: List[Tuple[float, float]],
46
+ bin_centers: List[float],
47
+ zero_inflated: bool = True,
48
+ num_vpt: Optional[int] = None,
49
+ vpt_drop: Optional[float] = None,
50
+ input_size: Optional[int] = None,
51
+ norm: str = "none",
52
+ act: str = "none"
53
+ ) -> None:
54
+ super().__init__()
55
+ assert model_name in supported_models, f"Model name should be one of {supported_models}, but got {model_name}."
56
+ self.model_name = model_name
57
+
58
+ if input_size is not None:
59
+ input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
60
+ assert len(input_size) == 2 and input_size[0] > 0 and input_size[1] > 0, f"Expected input_size to be a tuple of two positive integers, got {input_size}"
61
+ self.input_size = input_size
62
+
63
+ assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}"
64
+ assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}"
65
+ assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
66
+ bins = [(float(b[0]), float(b[1])) for b in bins]
67
+ assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"
68
+
69
+ self.block_size = block_size
70
+ self.bins = bins
71
+ self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1))
72
+
73
+ self.zero_inflated = zero_inflated
74
+ self.num_vpt = num_vpt
75
+ self.vpt_drop = vpt_drop
76
+ self.input_size = input_size
77
+
78
+ self.norm = norm
79
+ self.act = act
80
+
81
+ self._build_backbone()
82
+ self._build_head()
83
+
84
+ def _build_backbone(self) -> None:
85
+ model_name = self.model_name
86
+ if model_name == "csrnet":
87
+ self.backbone = _csrnet(self.block_size, self.norm, self.act)
88
+ elif model_name == "csrnet_bn":
89
+ self.backbone = _csrnet_bn(self.block_size, self.norm, self.act)
90
+ elif model_name == "cannet":
91
+ self.backbone = _cannet(self.block_size, self.norm, self.act)
92
+ elif model_name == "cannet_bn":
93
+ self.backbone = _cannet_bn(self.block_size, self.norm, self.act)
94
+ elif model_name == "vgg11":
95
+ self.backbone = _vgg_encoder("vgg11", self.block_size, self.norm, self.act)
96
+ elif model_name == "vgg11_ae":
97
+ self.backbone = _vgg_encoder_decoder("vgg11", self.block_size, self.norm, self.act)
98
+ elif model_name == "vgg11_bn":
99
+ self.backbone = _vgg_encoder("vgg11_bn", self.block_size, self.norm, self.act)
100
+ elif model_name == "vgg11_bn_ae":
101
+ self.backbone = _vgg_encoder_decoder("vgg11_bn", self.block_size, self.norm, self.act)
102
+ elif model_name == "vgg13":
103
+ self.backbone = _vgg_encoder("vgg13", self.block_size, self.norm, self.act)
104
+ elif model_name == "vgg13_ae":
105
+ self.backbone = _vgg_encoder_decoder("vgg13", self.block_size, self.norm, self.act)
106
+ elif model_name == "vgg13_bn":
107
+ self.backbone = _vgg_encoder("vgg13_bn", self.block_size, self.norm, self.act)
108
+ elif model_name == "vgg13_bn_ae":
109
+ self.backbone = _vgg_encoder_decoder("vgg13_bn", self.block_size, self.norm, self.act)
110
+ elif model_name == "vgg16":
111
+ self.backbone = _vgg_encoder("vgg16", self.block_size, self.norm, self.act)
112
+ elif model_name == "vgg16_ae":
113
+ self.backbone = _vgg_encoder_decoder("vgg16", self.block_size, self.norm, self.act)
114
+ elif model_name == "vgg16_bn":
115
+ self.backbone = _vgg_encoder("vgg16_bn", self.block_size, self.norm, self.act)
116
+ elif model_name == "vgg16_bn_ae":
117
+ self.backbone = _vgg_encoder_decoder("vgg16_bn", self.block_size, self.norm, self.act)
118
+ elif model_name == "vgg19":
119
+ self.backbone = _vgg_encoder("vgg19", self.block_size, self.norm, self.act)
120
+ elif model_name == "vgg19_ae":
121
+ self.backbone = _vgg_encoder_decoder("vgg19", self.block_size, self.norm, self.act)
122
+ elif model_name == "vgg19_bn":
123
+ self.backbone = _vgg_encoder("vgg19_bn", self.block_size, self.norm, self.act)
124
+ elif model_name == "vgg19_bn_ae":
125
+ self.backbone = _vgg_encoder_decoder("vgg19_bn", self.block_size, self.norm, self.act)
126
+ elif model_name in supported_vit_backbones:
127
+ self.backbone = _vit(model_name, block_size=self.block_size, num_vpt=self.num_vpt, vpt_drop=self.vpt_drop, input_size=self.input_size, norm=self.norm, act=self.act)
128
+ elif model_name in available_hrnets:
129
+ self.backbone = _hrnet(model_name, block_size=self.block_size, norm=self.norm, act=self.act)
130
+ else:
131
+ self.backbone = _timm_model(model_name, self.block_size, self.norm, self.act)
132
+
133
+ def _build_head(self) -> None:
134
+ channels = self.backbone.decoder_channels
135
+ if self.zero_inflated:
136
+ self.bin_head = conv1x1(
137
+ in_channels=channels,
138
+ out_channels=len(self.bins) - 1,
139
+ )
140
+ self.pi_head = conv1x1(
141
+ in_channels=channels,
142
+ out_channels=2,
143
+ ) # predicts P(structural zero) vs. P(non-zero) for each block
144
+ else:
145
+ self.bin_head = conv1x1(
146
+ in_channels=channels,
147
+ out_channels=len(self.bins),
148
+ )
149
+
150
+ def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
151
+ x = self.backbone(x)
152
+
153
+ if self.zero_inflated:
154
+ logit_pi_maps = self.pi_head(x) # shape: (B, 2, H, W)
155
+ logit_maps = self.bin_head(x) # shape: (B, C, H, W)
156
+ lambda_maps = (logit_maps.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # shape: (B, 1, H, W)
157
+
158
+ # logit_pi_maps.softmax(dim=1)[:, 0] is the probability of zeros
159
+ den_maps = logit_pi_maps.softmax(dim=1)[:, 1:] * lambda_maps # expectation of the Poisson distribution
160
+
161
+ if self.training:
162
+ return logit_pi_maps, logit_maps, lambda_maps, den_maps
163
+ else:
164
+ return den_maps
165
+
166
+ else:
167
+ logit_maps = self.bin_head(x)
168
+ den_maps = (logit_maps.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True)
169
+
170
+ if self.training:
171
+ return logit_maps, den_maps
172
+ else:
173
+ return den_maps
174
+
175
+
176
+ def _ebc(
177
+ model_name: str,
178
+ block_size: int,
179
+ bins: List[Tuple[float, float]],
180
+ bin_centers: List[float],
181
+ zero_inflated: bool = True,
182
+ num_vpt: Optional[int] = None,
183
+ vpt_drop: Optional[float] = None,
184
+ input_size: Optional[int] = None,
185
+ norm: str = "none",
186
+ act: str = "none"
187
+ ) -> EBC:
188
+ return EBC(
189
+ model_name=model_name,
190
+ block_size=block_size,
191
+ bins=bins,
192
+ bin_centers=bin_centers,
193
+ zero_inflated=zero_inflated,
194
+ num_vpt=num_vpt,
195
+ vpt_drop=vpt_drop,
196
+ input_size=input_size,
197
+ norm=norm,
198
+ act=act
199
+ )
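
A small numeric sketch of the zero-inflated expectation computed in EBC.forward above (all values are illustrative): the density of a block is the probability that the block is non-empty times the expected count over the non-zero bins.

import torch

bin_centers = torch.tensor([0.0, 1.0, 2.5, 5.0]).view(1, -1, 1, 1)
logit_pi = torch.tensor([2.0, 0.0]).view(1, 2, 1, 1)   # [zero, non-zero] logits for one block
logit_bins = torch.zeros(1, 3, 1, 1)                   # uniform over the three non-zero bins

lam = (logit_bins.softmax(dim=1) * bin_centers[:, 1:]).sum(dim=1, keepdim=True)
p_nonzero = logit_pi.softmax(dim=1)[:, 1:]
den = p_nonzero * lam
print(lam.item(), p_nonzero.item(), den.item())        # ~2.833, ~0.119, ~0.338
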
models/ebc/timm_models.py ADDED
@@ -0,0 +1,318 @@
1
+ from timm import create_model
2
+ from torch import nn, Tensor
3
+ from typing import Optional
4
+ from functools import partial
5
+
6
+ from ..utils import _get_activation, _get_norm_layer, ConvUpsample, ConvDownsample
7
+ from ..utils import LightConvUpsample, LightConvDownsample, LighterConvUpsample, LighterConvDownsample
8
+ from ..utils import ConvRefine, LightConvRefine, LighterConvRefine
9
+
10
+ regular_models = [
11
+ "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
12
+ "convnext_nano", "convnext_tiny", "convnext_small", "convnext_base",
13
+ "mobilenetv4_conv_large",
14
+ ]
15
+
16
+ heavy_models = [
17
+ "convnext_large", "convnext_xlarge", "convnext_xxlarge",
18
+ ]
19
+
20
+ light_models = [
21
+ "mobilenetv1_100", "mobilenetv1_125",
22
+ "mobilenetv2_100", "mobilenetv2_140",
23
+ "mobilenetv3_large_100",
24
+ "mobilenetv4_conv_medium",
25
+
26
+ ]
27
+
28
+ lighter_models = [
29
+ "mobilenetv2_050",
30
+ "mobilenetv3_small_050", "mobilenetv3_small_075", "mobilenetv3_small_100",
31
+ "mobilenetv4_conv_small_050", "mobilenetv4_conv_small"
32
+ ]
33
+
34
+ supported_models = regular_models + heavy_models + light_models + lighter_models
35
+
36
+
37
+ refiner_in_channels = {
38
+ # ResNet
39
+ "resnet18": 512,
40
+ "resnet34": 512,
41
+ "resnet50": 2048,
42
+ "resnet101": 2048,
43
+ "resnet152": 2048,
44
+ # ConvNeXt
45
+ "convnext_nano": 640,
46
+ "convnext_tiny": 768,
47
+ "convnext_small": 768,
48
+ "convnext_base": 1024,
49
+ "convnext_large": 1536,
50
+ "convnext_xlarge": 2048,
51
+ "convnext_xxlarge": 3072,
52
+ # MobileNet V1
53
+ "mobilenetv1_100": 1024,
54
+ "mobilenetv1_125": 1280,
55
+ # MobileNet V2
56
+ "mobilenetv2_050": 160,
57
+ "mobilenetv2_100": 320,
58
+ "mobilenetv2_140": 448,
59
+ # MobileNet V3
60
+ "mobilenetv3_small_050": 288,
61
+ "mobilenetv3_small_075": 432,
62
+ "mobilenetv3_small_100": 576,
63
+ "mobilenetv3_large_100": 960,
64
+ # MobileNet V4
65
+ "mobilenetv4_conv_small_050": 480,
66
+ "mobilenetv4_conv_small": 960,
67
+ "mobilenetv4_conv_medium": 960,
68
+ "mobilenetv4_conv_large": 960,
69
+ }
70
+
71
+
72
+ refiner_out_channels = {
73
+ # ResNet
74
+ "resnet18": 512,
75
+ "resnet34": 512,
76
+ "resnet50": 2048,
77
+ "resnet101": 2048,
78
+ "resnet152": 2048,
79
+ # ConvNeXt
80
+ "convnext_nano": 640,
81
+ "convnext_tiny": 768,
82
+ "convnext_small": 768,
83
+ "convnext_base": 1024,
84
+ "convnext_large": 1536,
85
+ "convnext_xlarge": 2048,
86
+ "convnext_xxlarge": 3072,
87
+ # MobileNet V1
88
+ "mobilenetv1_100": 512,
89
+ "mobilenetv1_125": 640,
90
+ # MobileNet V2
91
+ "mobilenetv2_050": 160,
92
+ "mobilenetv2_100": 320,
93
+ "mobilenetv2_140": 448,
94
+ # MobileNet V3
95
+ "mobilenetv3_small_050": 288,
96
+ "mobilenetv3_small_075": 432,
97
+ "mobilenetv3_small_100": 576,
98
+ "mobilenetv3_large_100": 480,
99
+ # MobileNet V4
100
+ "mobilenetv4_conv_small_050": 480,
101
+ "mobilenetv4_conv_small": 960,
102
+ "mobilenetv4_conv_medium": 960,
103
+ "mobilenetv4_conv_large": 960,
104
+ }
105
+
106
+
107
+ groups = {
108
+ # ResNet
109
+ "resnet18": 1,
110
+ "resnet34": 1,
111
+ "resnet50": refiner_in_channels["resnet50"] // 512,
112
+ "resnet101": refiner_in_channels["resnet101"] // 512,
113
+ "resnet152": refiner_in_channels["resnet152"] // 512,
114
+ # ConvNeXt
115
+ "convnext_nano": 8,
116
+ "convnext_tiny": 8,
117
+ "convnext_small": 8,
118
+ "convnext_base": 8,
119
+ "convnext_large": refiner_in_channels["convnext_large"] // 512,
120
+ "convnext_xlarge": refiner_in_channels["convnext_xlarge"] // 512,
121
+ "convnext_xxlarge": refiner_in_channels["convnext_xxlarge"] // 512,
122
+ # MobileNet V1
123
+ "mobilenetv1_100": None,
124
+ "mobilenetv1_125": None,
125
+ # MobileNet V2
126
+ "mobilenetv2_050": None,
127
+ "mobilenetv2_100": None,
128
+ "mobilenetv2_140": None,
129
+ # MobileNet V3
130
+ "mobilenetv3_small_050": None,
131
+ "mobilenetv3_small_075": None,
132
+ "mobilenetv3_small_100": None,
133
+ "mobilenetv3_large_100": None,
134
+ # MobileNet V4
135
+ "mobilenetv4_conv_small_050": None,
136
+ "mobilenetv4_conv_small": None,
137
+ "mobilenetv4_conv_medium": None,
138
+ "mobilenetv4_conv_large": 1,
139
+ }
140
+
141
+
142
+ class TIMMModel(nn.Module):
143
+ def __init__(
144
+ self,
145
+ model_name: str,
146
+ block_size: Optional[int] = None,
147
+ norm: str = "none",
148
+ act: str = "none"
149
+ ) -> None:
150
+ super().__init__()
151
+ assert model_name in supported_models, f"Backbone {model_name} not supported. Supported models are {supported_models}"
152
+ assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
153
+ self.model_name = model_name
154
+ self.encoder = create_model(model_name, pretrained=True, features_only=True, out_indices=[-1])
155
+ self.encoder_channels = self.encoder.feature_info.channels()[-1]
156
+ self.encoder_reduction = self.encoder.feature_info.reduction()[-1]
157
+ self.block_size = block_size if block_size is not None else self.encoder_reduction
158
+
159
+ if model_name in lighter_models:
160
+ upsample_block = LighterConvUpsample
161
+ downsample_block = LighterConvDownsample
162
+ decoder_block = LighterConvRefine
163
+ elif model_name in light_models:
164
+ upsample_block = LightConvUpsample
165
+ downsample_block = LightConvDownsample
166
+ decoder_block = LightConvRefine
167
+ else:
168
+ upsample_block = partial(ConvUpsample, groups=groups[model_name])
169
+ downsample_block = partial(ConvDownsample, groups=groups[model_name])
170
+ decoder_block = partial(ConvRefine, groups=groups[model_name])
171
+
172
+
173
+ if norm == "bn":
174
+ norm_layer = nn.BatchNorm2d
175
+ elif norm == "ln":
176
+ norm_layer = nn.LayerNorm
177
+ else:
178
+ norm_layer = _get_norm_layer(self.encoder)
179
+
180
+ if act == "relu":
181
+ activation = nn.ReLU(inplace=True)
182
+ elif act == "gelu":
183
+ activation = nn.GELU()
184
+ else:
185
+ activation = _get_activation(self.encoder)
186
+
187
+ if self.block_size > self.encoder_reduction:
188
+ if self.block_size > self.encoder_reduction * 2:
189
+ assert self.block_size == self.encoder_reduction * 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
190
+ self.refiner = nn.Sequential(
191
+ downsample_block(
192
+ in_channels=self.encoder_channels,
193
+ out_channels=refiner_in_channels[self.model_name],
194
+ norm_layer=norm_layer,
195
+ activation=activation,
196
+ ),
197
+ downsample_block(
198
+ in_channels=refiner_in_channels[self.model_name],
199
+ out_channels=refiner_out_channels[self.model_name],
200
+ norm_layer=norm_layer,
201
+ activation=activation,
202
+ )
203
+ )
204
+ else:
205
+ assert self.block_size == self.encoder_reduction * 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
206
+ self.refiner = downsample_block(
207
+ in_channels=self.encoder_channels,
208
+ out_channels=refiner_out_channels[self.model_name],
209
+ norm_layer=norm_layer,
210
+ activation=activation,
211
+ )
212
+
213
+ self.refiner_channels = refiner_out_channels[self.model_name]
214
+
215
+ elif self.block_size < self.encoder_reduction:
216
+ if self.block_size < self.encoder_reduction // 2:
217
+ assert self.block_size == self.encoder_reduction // 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
218
+ self.refiner = nn.Sequential(
219
+ upsample_block(
220
+ in_channels=self.encoder_channels,
221
+ out_channels=refiner_in_channels[self.model_name],
222
+ norm_layer=norm_layer,
223
+ activation=activation,
224
+ ),
225
+ upsample_block(
226
+ in_channels=refiner_in_channels[self.model_name],
227
+ out_channels=refiner_out_channels[self.model_name],
228
+ norm_layer=norm_layer,
229
+ activation=activation,
230
+ )
231
+ )
232
+ else:
233
+ assert self.block_size == self.encoder_reduction // 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
234
+ self.refiner = upsample_block(
235
+ in_channels=self.encoder_channels,
236
+ out_channels=refiner_out_channels[self.model_name],
237
+ norm_layer=norm_layer,
238
+ activation=activation,
239
+ )
240
+
241
+ self.refiner_channels = refiner_out_channels[self.model_name]
242
+
243
+ else:
244
+ self.refiner = nn.Identity()
245
+ self.refiner_channels = self.encoder_channels
246
+
247
+ self.refiner_reduction = self.block_size
248
+
249
+ if self.refiner_channels <= 256:
250
+ self.decoder = nn.Identity()
251
+ self.decoder_channels = self.refiner_channels
252
+ elif self.refiner_channels <= 512:
253
+ self.decoder = decoder_block(
254
+ in_channels=self.refiner_channels,
255
+ out_channels=self.refiner_channels // 2,
256
+ norm_layer=norm_layer,
257
+ activation=activation,
258
+ )
259
+ self.decoder_channels = self.refiner_channels // 2
260
+ elif self.refiner_channels <= 1024:
261
+ self.decoder = nn.Sequential(
262
+ decoder_block(
263
+ in_channels=self.refiner_channels,
264
+ out_channels=self.refiner_channels // 2,
265
+ norm_layer=norm_layer,
266
+ activation=activation,
267
+ ),
268
+ decoder_block(
269
+ in_channels=self.refiner_channels // 2,
270
+ out_channels=self.refiner_channels // 4,
271
+ norm_layer=norm_layer,
272
+ activation=activation,
273
+ ),
274
+ )
275
+ self.decoder_channels = self.refiner_channels // 4
276
+ else:
277
+ self.decoder = nn.Sequential(
278
+ decoder_block(
279
+ in_channels=self.refiner_channels,
280
+ out_channels=self.refiner_channels // 2,
281
+ norm_layer=norm_layer,
282
+ activation=activation,
283
+ ),
284
+ decoder_block(
285
+ in_channels=self.refiner_channels // 2,
286
+ out_channels=self.refiner_channels // 4,
287
+ norm_layer=norm_layer,
288
+ activation=activation,
289
+ ),
290
+ decoder_block(
291
+ in_channels=self.refiner_channels // 4,
292
+ out_channels=self.refiner_channels // 8,
293
+ norm_layer=norm_layer,
294
+ activation=activation,
295
+ ),
296
+ )
297
+ self.decoder_channels = self.refiner_channels // 8
298
+
299
+ self.decoder_reduction = self.refiner_reduction
300
+
301
+ def encode(self, x: Tensor) -> Tensor:
302
+ return self.encoder(x)[0]
303
+
304
+ def refine(self, x: Tensor) -> Tensor:
305
+ return self.refiner(x)
306
+
307
+ def decode(self, x: Tensor) -> Tensor:
308
+ return self.decoder(x)
309
+
310
+ def forward(self, x: Tensor) -> Tensor:
311
+ x = self.encode(x)
312
+ x = self.refine(x)
313
+ x = self.decode(x)
314
+ return x
315
+
316
+
317
+ def _timm_model(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> TIMMModel:
318
+ return TIMMModel(model_name, block_size=block_size, norm=norm, act=act)
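For orientation, a minimal usage sketch of the `_timm_model` factory above (not part of the commit; it assumes the pretrained `resnet18` weights can be fetched through timm, and the shapes shown are illustrative):

import torch

model = _timm_model("resnet18", block_size=16, norm="bn", act="relu")
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = model(x)
# block_size=16 halves ResNet-18's native stride of 32 via the upsampling refiner,
# so the output grid is 224 / 16 = 14 per side; the decoder then halves the channels.
print(y.shape, model.decoder_channels, model.decoder_reduction)  # (1, 256, 14, 14), 256, 16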
models/ebc/utils.py ADDED
@@ -0,0 +1,37 @@
1
+ from torch import nn
2
+ from typing import List, Union
3
+
4
+
5
+ vgg_urls = {
6
+ "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
7
+ "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
8
+ "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
9
+ "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
10
+ "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
11
+ "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
12
+ "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
13
+ "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
14
+ }
15
+
16
+
17
+ vgg_cfgs = {
18
+ "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
19
+ "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
20
+ "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512],
21
+ "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512]
22
+ }
23
+
24
+
25
+ def make_vgg_layers(cfg: List[Union[str, int]], in_channels: int = 3, batch_norm: bool = False, dilation: int = 1) -> nn.Sequential:
26
+ layers = []
27
+ for v in cfg:
28
+ if v == "M":
29
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
30
+ else:
31
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=dilation, dilation=dilation)
32
+ if batch_norm:
33
+ layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
34
+ else:
35
+ layers += [conv2d, nn.ReLU(inplace=True)]
36
+ in_channels = v
37
+ return nn.Sequential(*layers)
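A quick sketch of what `make_vgg_layers` produces for the "D" (VGG-16) configuration above: four max-pool stages, so the trunk has an output stride of 16 and 512 output channels (illustrative only, not part of the commit):

import torch

features = make_vgg_layers(vgg_cfgs["D"], in_channels=3, batch_norm=True)
x = torch.randn(1, 3, 224, 224)
print(features(x).shape)  # torch.Size([1, 512, 14, 14])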
models/ebc/vgg.py ADDED
@@ -0,0 +1,255 @@
1
+ from torch import nn, Tensor
2
+ from torch.hub import load_state_dict_from_url
3
+ from typing import Optional
4
+
5
+ from .utils import make_vgg_layers, vgg_cfgs, vgg_urls
6
+ from ..utils import _init_weights, _get_norm_layer, _get_activation
7
+ from ..utils import ConvDownsample, ConvUpsample
8
+
9
+
10
+ vgg_models = [
11
+ "vgg11", "vgg11_bn",
12
+ "vgg13", "vgg13_bn",
13
+ "vgg16", "vgg16_bn",
14
+ "vgg19", "vgg19_bn",
15
+ ]
16
+
17
+ decoder_cfg = [512, 256, 128]
18
+
19
+
20
+ class VGGEncoder(nn.Module):
21
+ def __init__(
22
+ self,
23
+ model_name: str,
24
+ block_size: Optional[int] = None,
25
+ norm: str = "none",
26
+ act: str = "none",
27
+ ) -> None:
28
+ super().__init__()
29
+ assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
30
+ assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
31
+ self.model_name = model_name
32
+
33
+ if model_name == "vgg11":
34
+ self.encoder = vgg11()
35
+ elif model_name == "vgg11_bn":
36
+ self.encoder = vgg11_bn()
37
+ elif model_name == "vgg13":
38
+ self.encoder = vgg13()
39
+ elif model_name == "vgg13_bn":
40
+ self.encoder = vgg13_bn()
41
+ elif model_name == "vgg16":
42
+ self.encoder = vgg16()
43
+ elif model_name == "vgg16_bn":
44
+ self.encoder = vgg16_bn()
45
+ elif model_name == "vgg19":
46
+ self.encoder = vgg19()
47
+ else: # model_name == "vgg19_bn"
48
+ self.encoder = vgg19_bn()
49
+
50
+ self.encoder_channels = 512
51
+ self.encoder_reduction = 16
52
+ self.block_size = block_size if block_size is not None else self.encoder_reduction
53
+
54
+ if norm == "bn":
55
+ norm_layer = nn.BatchNorm2d
56
+ elif norm == "ln":
57
+ norm_layer = nn.LayerNorm
58
+ else:
59
+ norm_layer = _get_norm_layer(self.encoder)
60
+
61
+ if act == "relu":
62
+ activation = nn.ReLU(inplace=True)
63
+ elif act == "gelu":
64
+ activation = nn.GELU()
65
+ else:
66
+ activation = _get_activation(self.encoder)
67
+
68
+ if self.encoder_reduction >= self.block_size: # 8, 16
69
+ self.refiner = ConvUpsample(
70
+ in_channels=self.encoder_channels,
71
+ out_channels=self.encoder_channels,
72
+ scale_factor=self.encoder_reduction // self.block_size,
73
+ norm_layer=norm_layer,
74
+ activation=activation,
75
+ )
76
+ else: # 32
77
+ self.refiner = ConvDownsample(
78
+ in_channels=self.encoder_channels,
79
+ out_channels=self.encoder_channels,
80
+ norm_layer=norm_layer,
81
+ activation=activation,
82
+ )
83
+ self.refiner_channels = self.encoder_channels
84
+ self.refiner_reduction = self.block_size
85
+
86
+ self.decoder = nn.Identity()
87
+ self.decoder_channels = self.encoder_channels
88
+ self.decoder_reduction = self.refiner_reduction
89
+
90
+ def encode(self, x: Tensor) -> Tensor:
91
+ return self.encoder(x)
92
+
93
+ def refine(self, x: Tensor) -> Tensor:
94
+ return self.refiner(x)
95
+
96
+ def decode(self, x: Tensor) -> Tensor:
97
+ return self.decoder(x)
98
+
99
+ def forward(self, x: Tensor) -> Tensor:
100
+ x = self.encode(x)
101
+ x = self.refine(x)
102
+ x = self.decode(x)
103
+ return x
104
+
105
+
106
+ class VGGEncoderDecoder(nn.Module):
107
+ def __init__(
108
+ self,
109
+ model_name: str,
110
+ block_size: Optional[int] = None,
111
+ norm: str = "none",
112
+ act: str = "none",
113
+ ) -> None:
114
+ super().__init__()
115
+ assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
116
+ assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
117
+ self.model_name = model_name
118
+
119
+ if model_name == "vgg11":
120
+ encoder = vgg11()
121
+ elif model_name == "vgg11_bn":
122
+ encoder = vgg11_bn()
123
+ elif model_name == "vgg13":
124
+ encoder = vgg13()
125
+ elif model_name == "vgg13_bn":
126
+ encoder = vgg13_bn()
127
+ elif model_name == "vgg16":
128
+ encoder = vgg16()
129
+ elif model_name == "vgg16_bn":
130
+ encoder = vgg16_bn()
131
+ elif model_name == "vgg19":
132
+ encoder = vgg19()
133
+ else: # model_name == "vgg19_bn"
134
+ encoder = vgg19_bn()
135
+
136
+ encoder_channels = 512
137
+ encoder_reduction = 16
138
+ decoder = make_vgg_layers(decoder_cfg, in_channels=encoder_channels, batch_norm="bn" in model_name, dilation=1)
139
+ decoder.apply(_init_weights)
140
+
141
+ if norm == "bn":
142
+ norm_layer = nn.BatchNorm2d
143
+ elif norm == "ln":
144
+ norm_layer = nn.LayerNorm
145
+ else:
146
+ norm_layer = _get_norm_layer(encoder)
147
+
148
+ if act == "relu":
149
+ activation = nn.ReLU(inplace=True)
150
+ elif act == "gelu":
151
+ activation = nn.GELU()
152
+ else:
153
+ activation = _get_activation(encoder)
154
+
155
+ self.encoder = nn.Sequential(encoder, decoder)
156
+ self.encoder_channels = decoder_cfg[-1]
157
+ self.encoder_reduction = encoder_reduction
158
+ self.block_size = block_size if block_size is not None else self.encoder_reduction
159
+
160
+ if self.encoder_reduction >= self.block_size:
161
+ self.refiner = ConvUpsample(
162
+ in_channels=self.encoder_channels,
163
+ out_channels=self.encoder_channels,
164
+ scale_factor=self.encoder_reduction // self.block_size,
165
+ norm_layer=norm_layer,
166
+ activation=activation,
167
+ )
168
+ else:
169
+ self.refiner = ConvDownsample(
170
+ in_channels=self.encoder_channels,
171
+ out_channels=self.encoder_channels,
172
+ norm_layer=norm_layer,
173
+ activation=activation,
174
+ )
175
+ self.refiner_channels = self.encoder_channels
176
+ self.refiner_reduction = self.block_size
177
+
178
+ self.decoder = nn.Identity()
179
+ self.decoder_channels = self.refiner_channels
180
+ self.decoder_reduction = self.refiner_reduction
181
+
182
+ def encode(self, x: Tensor) -> Tensor:
183
+ return self.encoder(x)
184
+
185
+ def refine(self, x: Tensor) -> Tensor:
186
+ return self.refiner(x)
187
+
188
+ def decode(self, x: Tensor) -> Tensor:
189
+ return self.decoder(x)
190
+
191
+ def forward(self, x: Tensor) -> Tensor:
192
+ x = self.encode(x)
193
+ x = self.refine(x)
194
+ x = self.decode(x)
195
+ return x
196
+
197
+
198
+ class VGG(nn.Module):
199
+ def __init__(
200
+ self,
201
+ features: nn.Module,
202
+ ) -> None:
203
+ super().__init__()
204
+ self.features = features
205
+
206
+ def forward(self, x: Tensor) -> Tensor:
207
+ x = self.features(x)
208
+ return x
209
+
210
+
211
+ def vgg11() -> VGG:
212
+ model = VGG(make_vgg_layers(vgg_cfgs["A"]))
213
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg11"]), strict=False)
214
+ return model
215
+
216
+ def vgg11_bn() -> VGG:
217
+ model = VGG(make_vgg_layers(vgg_cfgs["A"], batch_norm=True))
218
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg11_bn"]), strict=False)
219
+ return model
220
+
221
+ def vgg13() -> VGG:
222
+ model = VGG(make_vgg_layers(vgg_cfgs["B"]))
223
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg13"]), strict=False)
224
+ return model
225
+
226
+ def vgg13_bn() -> VGG:
227
+ model = VGG(make_vgg_layers(vgg_cfgs["B"], batch_norm=True))
228
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg13_bn"]), strict=False)
229
+ return model
230
+
231
+ def vgg16() -> VGG:
232
+ model = VGG(make_vgg_layers(vgg_cfgs["D"]))
233
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg16"]), strict=False)
234
+ return model
235
+
236
+ def vgg16_bn() -> VGG:
237
+ model = VGG(make_vgg_layers(vgg_cfgs["D"], batch_norm=True))
238
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg16_bn"]), strict=False)
239
+ return model
240
+
241
+ def vgg19() -> VGG:
242
+ model = VGG(make_vgg_layers(vgg_cfgs["E"]))
243
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg19"]), strict=False)
244
+ return model
245
+
246
+ def vgg19_bn() -> VGG:
247
+ model = VGG(make_vgg_layers(vgg_cfgs["E"], batch_norm=True))
248
+ model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg19_bn"]), strict=False)
249
+ return model
250
+
251
+ def _vgg_encoder(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> VGGEncoder:
252
+ return VGGEncoder(model_name, block_size, norm=norm, act=act)
253
+
254
+ def _vgg_encoder_decoder(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> VGGEncoderDecoder:
255
+ return VGGEncoderDecoder(model_name, block_size, norm=norm, act=act)
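A hedged usage sketch of the VGG wrappers above (assumes the torchvision VGG weights can be downloaded; shapes are illustrative):

import torch

encoder = _vgg_encoder("vgg19_bn", block_size=8, norm="bn", act="relu")
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = encoder(x)
# encoder_reduction is 16 (four poolings); block_size=8 selects a 2x ConvUpsample refiner.
print(y.shape)  # torch.Size([1, 512, 28, 28])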
models/ebc/vit.py ADDED
@@ -0,0 +1,323 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import timm
4
+ from einops import rearrange
5
+ import torch.nn.functional as F
6
+
7
+ import math
8
+ from typing import Optional, Tuple
9
+ from ..utils import ConvUpsample, ConvDownsample, _get_activation, _get_norm_layer, ConvRefine
10
+
11
+
12
+ supported_vit_backbones = [
13
+ # Tiny
14
+ "vit_tiny_patch16_224", "vit_tiny_patch16_384",
15
+ # Small
16
+ "vit_small_patch8_224",
17
+ "vit_small_patch16_224", "vit_small_patch16_384",
18
+ "vit_small_patch32_224", "vit_small_patch32_384",
19
+ # Base
20
+ "vit_base_patch8_224",
21
+ "vit_base_patch16_224", "vit_base_patch16_384",
22
+ "vit_base_patch32_224", "vit_base_patch32_384",
23
+ # Large
24
+ "vit_large_patch16_224", "vit_large_patch16_384",
25
+ "vit_large_patch32_224", "vit_large_patch32_384",
26
+ # Huge
27
+ "vit_huge_patch14_224",
28
+ ]
29
+
30
+
31
+ refiner_channels = {
32
+ "vit_tiny_patch16_224": 192,
33
+ "vit_tiny_patch16_384": 192,
34
+ "vit_small_patch8_224": 384,
35
+ "vit_small_patch16_224": 384,
36
+ "vit_small_patch16_384": 384,
37
+ "vit_small_patch32_224": 384,
38
+ "vit_small_patch32_384": 384,
39
+ "vit_base_patch8_224": 768,
40
+ "vit_base_patch16_224": 768,
41
+ "vit_base_patch16_384": 768,
42
+ "vit_base_patch32_224": 768,
43
+ "vit_base_patch32_384": 768,
44
+ "vit_large_patch16_224": 1024,
45
+ "vit_large_patch16_384": 1024,
46
+ "vit_large_patch32_224": 1024,
47
+ "vit_large_patch32_384": 1024,
48
+ }
49
+
50
+ refiner_groups = {
51
+ "vit_tiny_patch16_224": 1,
52
+ "vit_tiny_patch16_384": 1,
53
+ "vit_small_patch8_224": 1,
54
+ "vit_small_patch16_224": 1,
55
+ "vit_small_patch16_384": 1,
56
+ "vit_small_patch32_224": 1,
57
+ "vit_small_patch32_384": 1,
58
+ "vit_base_patch8_224": 1,
59
+ "vit_base_patch16_224": 1,
60
+ "vit_base_patch16_384": 1,
61
+ "vit_base_patch32_224": 1,
62
+ "vit_base_patch32_384": 1,
63
+ "vit_large_patch16_224": 1,
64
+ "vit_large_patch16_384": 1,
65
+ "vit_large_patch32_224": 1,
66
+ "vit_large_patch32_384": 1,
67
+ }
68
+
69
+
70
+ class ViT(nn.Module):
71
+ def __init__(
72
+ self,
73
+ model_name: str,
74
+ block_size: Optional[int] = None,
75
+ num_vpt: int = 32,
76
+ vpt_drop: float = 0.0,
77
+ input_size: Optional[Tuple[int, int]] = None,
78
+ norm: str = "none",
79
+ act: str = "none"
80
+ ) -> None:
81
+ super().__init__()
82
+ assert model_name in supported_vit_backbones, f"Model {model_name} not supported"
83
+ assert num_vpt >= 0, f"Number of VPT tokens should be non-negative, but got {num_vpt}."
84
+ self.model_name = model_name
85
+
86
+ self.num_vpt = num_vpt
87
+ self.vpt_drop = vpt_drop
88
+
89
+ model = timm.create_model(model_name, pretrained=True)
90
+
91
+ self.input_size = input_size if input_size is not None else model.patch_embed.img_size
92
+ self.pretrain_size = model.patch_embed.img_size
93
+ self.patch_size = model.patch_embed.patch_size
94
+
95
+ if self.patch_size[0] in [8, 16, 32]:
96
+ assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
97
+ else: # patch_size == 14
98
+ assert block_size is None or block_size in [7, 14, 28], f"Block size should be one of [7, 14, 28], but got {block_size}."
99
+
100
+ self.num_layers = len(model.blocks)
101
+ self.embed_dim = model.cls_token.shape[-1]
102
+
103
+ if self.num_vpt > 0: # Use visual prompt tuning so freeze the backbone
104
+ for param in model.parameters():
105
+ param.requires_grad = False
106
+
107
+ # Setup VPT tokens
108
+ val = math.sqrt(6. / float(3 * self.patch_size[0] + self.embed_dim))
109
+ for idx in range(self.num_layers):
110
+ setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.embed_dim)))
111
+ nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
112
+ setattr(self, f"vpt_drop_{idx}", nn.Dropout(self.vpt_drop))
113
+
114
+ self.patch_embed = model.patch_embed
115
+ self.cls_token = model.cls_token
116
+ self.pos_embed = model.pos_embed
117
+ self.pos_drop = model.pos_drop
118
+ self.patch_drop = model.patch_drop
119
+ self.norm_pre = model.norm_pre
120
+
121
+ self.blocks = model.blocks
122
+ self.norm = model.norm
123
+
124
+ self.encoder_channels = self.embed_dim
125
+ self.encoder_reduction = self.patch_size[0]
126
+ self.block_size = block_size if block_size is not None else self.encoder_reduction
127
+
128
+ if norm == "bn":
129
+ norm_layer = nn.BatchNorm2d
130
+ elif norm == "ln":
131
+ norm_layer = nn.LayerNorm
132
+ else:
133
+ norm_layer = _get_norm_layer(model)
134
+
135
+ if act == "relu":
136
+ activation = nn.ReLU(inplace=True)
137
+ elif act == "gelu":
138
+ activation = nn.GELU()
139
+ else:
140
+ activation = _get_activation(model)
141
+
142
+ if self.block_size < self.encoder_reduction:
143
+ assert self.block_size == self.encoder_reduction // 2, f"Block size should be half of the encoder reduction, but got {self.block_size} and {self.encoder_reduction}."
144
+ self.refiner = ConvUpsample(
145
+ in_channels=self.encoder_channels,
146
+ out_channels=self.encoder_channels,
147
+ norm_layer=norm_layer,
148
+ activation=activation,
149
+ )
150
+ elif self.block_size > self.encoder_reduction:
151
+ assert self.block_size == self.encoder_reduction * 2, f"Block size should be double of the encoder reduction, but got {self.block_size} and {self.encoder_reduction}."
152
+ self.refiner = ConvDownsample(
153
+ in_channels=self.encoder_channels,
154
+ out_channels=self.encoder_channels,
155
+ norm_layer=norm_layer,
156
+ activation=activation,
157
+ )
158
+ else:
159
+ self.refiner = ConvRefine(
160
+ in_channels=self.encoder_channels,
161
+ out_channels=self.encoder_channels,
162
+ norm_layer=norm_layer,
163
+ activation=activation,
164
+ )
165
+
166
+ self.refiner_channels = self.encoder_channels
167
+ self.refiner_reduction = self.block_size
168
+
169
+ self.decoder = nn.Identity()
170
+ self.decoder_channels = self.refiner_channels
171
+ self.reduction = self.refiner_reduction
+ self.decoder_reduction = self.refiner_reduction  # also expose the attribute name used by the TIMM/VGG wrappers
172
+
173
+ # Adjust the positional embedding to match the new input size
174
+ self._adjust_pos_embed()
175
+
176
+ def _adjust_pos_embed(self) -> None:
+ """
+ Resize the pretrained positional embedding to match `self.input_size` and
+ re-register it as a parameter (frozen when visual prompt tuning is enabled).
+ """
184
+ self.pos_embed = nn.Parameter(self._interpolate_pos_embed(self.pretrain_size[0], self.pretrain_size[1], self.input_size[0], self.input_size[1]), requires_grad=self.num_vpt == 0)
185
+
186
+ def _interpolate_pos_embed(self, orig_h: int, orig_w: int, new_h: int, new_w: int) -> Tensor:
187
+ """
188
+ Interpolate the positional embedding to match the spatial resolution of the feature map.
189
+
190
+ Args:
191
+ orig_h, orig_w: The original spatial resolution of the image.
192
+ new_h, new_w: The new spatial resolution of the image.
193
+ """
194
+ if (orig_h, orig_w) == (new_h, new_w):
195
+ return self.pos_embed # (1, (h * w + 1), d)
196
+
197
+ orig_h_patches, orig_w_patches = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
198
+ new_h_patches, new_w_patches = new_h // self.patch_size[0], new_w // self.patch_size[1]
199
+ class_pos_embed, patch_pos_embed = self.pos_embed[:, :1, :], self.pos_embed[:, 1:, :]
200
+ patch_pos_embed = rearrange(patch_pos_embed, "1 (h w) d -> 1 d h w", h=orig_h_patches, w=orig_w_patches)
201
+ patch_pos_embed = F.interpolate(patch_pos_embed, size=(new_h_patches, new_w_patches), mode="bicubic", antialias=True)
202
+ patch_pos_embed = rearrange(patch_pos_embed, "1 d h w -> 1 (h w) d")
203
+ pos_embed = torch.cat((class_pos_embed, patch_pos_embed), dim=1)
204
+ return pos_embed
205
+
206
+ def train(self, mode: bool = True):
207
+ super().train(mode)  # keep self.training and all submodules in sync first
+ if self.num_vpt > 0 and mode:
+ # The backbone is frozen under VPT: put it back into eval mode and keep only
+ # the VPT dropouts, refiner and decoder in train mode.
208
+ self.patch_embed.eval()
209
+ self.pos_drop.eval()
210
+ self.patch_drop.eval()
211
+ self.norm_pre.eval()
212
+
213
+ self.blocks.eval()
214
+ self.norm.eval()
215
+
216
+ for idx in range(self.num_layers):
217
+ getattr(self, f"vpt_drop_{idx}").train()
218
+
219
+ self.refiner.train()
220
+ self.decoder.train()
221
+
222
+ else:
223
+ for module in self.children():
224
+ module.train(mode)
+ return self
225
+
226
+ def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
227
+ vpt = getattr(self, f"vpt_{layer}").unsqueeze(0).expand(batch_size, -1, -1).to(device) # (batch_size, num_vpt, embed_dim)
228
+ vpt = getattr(self, f"vpt_drop_{layer}")(vpt)
229
+
230
+ return vpt
231
+
232
+ def _forward_patch_embed(self, x: Tensor) -> Tensor:
233
+ # This step performs 1) embed x into patches; 2) append the class token; 3) add positional embeddings.
234
+ assert len(x.shape) == 4, f"Expected input to have shape (batch_size, 3, height, width), but got {x.shape}"
235
+ batch_size, _, height, width = x.shape
236
+
237
+ # Step 1: Embed x into patches
238
+ x = self.patch_embed(x) # (b, h * w, d)
239
+
240
+ # Step 2: Append the class token
241
+ cls_token = self.cls_token.expand(batch_size, 1, -1)
242
+ x = torch.cat([cls_token, x], dim=1)
243
+
244
+ # Step 3: Add positional embeddings
245
+ pos_embed = self._interpolate_pos_embed(orig_h=self.input_size[0], orig_w=self.input_size[1], new_h=height, new_w=width).expand(batch_size, -1, -1)
246
+ x = self.pos_drop(x + pos_embed)
247
+ return x
248
+
249
+ def _forward_vpt(self, x: Tensor, idx: int) -> Tensor:
250
+ batch_size = x.shape[0]
251
+ device = x.device
252
+
253
+ # Assemble
254
+ vpt = self._prepare_vpt(idx, batch_size, device)
255
+ x = torch.cat([
256
+ x[:, :1, :], # class token
257
+ vpt,
258
+ x[:, 1:, :] # patches
259
+ ], dim=1)
260
+
261
+ # Forward
262
+ x = self.blocks[idx](x)
263
+
264
+ # Disassemble
265
+ x = torch.cat([
266
+ x[:, :1, :], # class token
267
+ x[:, 1 + self.num_vpt:, :] # patches
268
+ ], dim=1)
269
+
270
+ return x
271
+
272
+ def _forward(self, x: Tensor, idx: int) -> Tensor:
273
+ x = self.blocks[idx](x)
274
+ return x
275
+
276
+ def encode(self, x: Tensor) -> Tensor:
277
+ orig_h, orig_w = x.shape[-2:]
278
+ num_patches_h, num_patches_w = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
279
+
280
+ x = self._forward_patch_embed(x)
281
+ x = self.patch_drop(x)
282
+ x = self.norm_pre(x)
283
+
284
+ for idx in range(self.num_layers):
285
+ x = self._forward_vpt(x, idx) if self.num_vpt > 0 else self._forward(x, idx)
286
+
287
+ x = self.norm(x)
288
+ x = x[:, 1:, :]
289
+ x = rearrange(x, "b (h w) d -> b d h w", h=num_patches_h, w=num_patches_w)
290
+ return x
291
+
292
+ def refine(self, x: Tensor) -> Tensor:
293
+ return self.refiner(x)
294
+
295
+ def decode(self, x: Tensor) -> Tensor:
296
+ return self.decoder(x)
297
+
298
+ def forward(self, x: Tensor) -> Tensor:
299
+ x = self.encode(x)
300
+ x = self.refine(x)
301
+ x = self.decode(x)
302
+ return x
303
+
304
+
305
+ def _vit(
306
+ model_name: str,
307
+ block_size: Optional[int] = None,
308
+ num_vpt: int = 32,
309
+ vpt_drop: float = 0.0,
310
+ input_size: Optional[Tuple[int, int]] = None,
311
+ norm: str = "none",
312
+ act: str = "none"
313
+ ) -> ViT:
314
+ model = ViT(
315
+ model_name=model_name,
316
+ block_size=block_size,
317
+ num_vpt=num_vpt,
318
+ vpt_drop=vpt_drop,
319
+ input_size=input_size,
320
+ norm=norm,
321
+ act=act
322
+ )
323
+ return model
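A minimal sketch of the `_vit` factory above under visual prompt tuning (assumes the timm ViT-B/16-384 weights are available; names and shapes are illustrative):

import torch

vit = _vit("vit_base_patch16_384", block_size=16, num_vpt=32, vpt_drop=0.1, input_size=(384, 384))
# With num_vpt > 0 the backbone and the resized pos_embed are frozen; only the
# vpt_* tokens and the refiner remain trainable.
trainable = {n for n, p in vit.named_parameters() if p.requires_grad}
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    y = vit(x)
print(y.shape)  # torch.Size([1, 768, 24, 24]) -- 384 / block_size 16 per side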
models/utils/__init__.py ADDED
@@ -0,0 +1,56 @@
1
+ from torch import nn
2
+ from typing import Optional
3
+ from functools import partial
4
+
5
+ from .utils import _init_weights, interpolate_pos_embed
6
+ from .blocks import DepthSeparableConv2d, conv1x1, conv3x3, Conv2dLayerNorm
7
+ from .refine import ConvRefine, LightConvRefine, LighterConvRefine
8
+ from .downsample import ConvDownsample, LightConvDownsample, LighterConvDownsample
9
+ from .upsample import ConvUpsample, LightConvUpsample, LighterConvUpsample
10
+ from .multi_scale import MultiScale
11
+ from .blocks import ConvAdapter, ViTAdapter
12
+
13
+
14
+ def _get_norm_layer(model: nn.Module) -> Optional[nn.Module]:
15
+ for module in model.modules():
16
+ if isinstance(module, nn.BatchNorm2d):
17
+ return nn.BatchNorm2d
18
+ elif isinstance(module, nn.GroupNorm):
19
+ num_groups = module.num_groups
20
+ return partial(nn.GroupNorm, num_groups=num_groups)
21
+ elif isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
22
+ return Conv2dLayerNorm
23
+ return None
24
+
25
+
26
+ def _get_activation(model: nn.Module) -> Optional[nn.Module]:
27
+ for module in model.modules():
28
+ if isinstance(module, nn.BatchNorm2d):
29
+ return nn.ReLU(inplace=True)
30
+ elif isinstance(module, nn.GroupNorm):
31
+ return nn.ReLU(inplace=True)
32
+ elif isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
33
+ return nn.GELU()
34
+ return nn.GELU()
35
+
36
+
37
+
38
+ __all__ = [
39
+ "_init_weights", "_check_norm_layer", "_check_activation",
40
+ "conv1x1",
41
+ "conv3x3",
42
+ "Conv2dLayerNorm",
43
+ "interpolate_pos_embed",
44
+ "DepthSeparableConv2d",
45
+ "ConvRefine",
46
+ "LightConvRefine",
47
+ "LighterConvRefine",
48
+ "ConvDownsample",
49
+ "LightConvDownsample",
50
+ "LighterConvDownsample",
51
+ "ConvUpsample",
52
+ "LightConvUpsample",
53
+ "LighterConvUpsample",
54
+ "MultiScale",
55
+ "ConvAdapter", "ViTAdapter",
56
+ ]
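The two helpers above pick a norm/activation family that matches whatever the wrapped encoder already uses; a small self-contained sketch (illustrative only):

from torch import nn

cnn = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
print(_get_norm_layer(cnn))    # <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
print(_get_activation(cnn))    # ReLU(inplace=True)

mlp = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
print(_get_norm_layer(mlp))    # Conv2dLayerNorm (channels-last LayerNorm applied to 2d maps)
print(_get_activation(mlp))    # GELU()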
models/utils/blocks.py ADDED
@@ -0,0 +1,617 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from einops.layers.torch import Rearrange
6
+
7
+ from typing import Callable, Optional, Sequence, Tuple, Union, List
8
+ import warnings
9
+
10
+ from .utils import _init_weights, _make_ntuple, _log_api_usage_once
11
+
12
+
13
+ def conv3x3(
14
+ in_channels: int,
15
+ out_channels: int,
16
+ stride: int = 1,
17
+ groups: int = 1,
18
+ dilation: int = 1,
19
+ bias: bool = True,
20
+ ) -> nn.Conv2d:
21
+ """3x3 convolution with padding"""
22
+ conv = nn.Conv2d(
23
+ in_channels,
24
+ out_channels,
25
+ kernel_size=3,
26
+ stride=stride,
27
+ padding=dilation,
28
+ groups=groups,
29
+ bias=bias,
30
+ dilation=dilation,
31
+ )
32
+ conv.apply(_init_weights)
33
+ return conv
34
+
35
+
36
+ def conv1x1(
37
+ in_channels: int,
38
+ out_channels: int,
39
+ stride: int = 1,
40
+ bias: bool = True,
41
+ ) -> nn.Conv2d:
42
+ """1x1 convolution"""
43
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=bias)
44
+ conv.apply(_init_weights)
45
+ return conv
46
+
47
+
48
+ class DepthSeparableConv2d(nn.Module):
49
+ def __init__(
50
+ self,
51
+ in_channels: int,
52
+ out_channels: int,
53
+ kernel_size: int,
54
+ stride: int = 1,
55
+ padding: int = 0,
56
+ dilation: int = 1,
57
+ bias: bool = True,
58
+ padding_mode: str = "zeros",
59
+ ) -> None:
60
+ super().__init__()
61
+ # Depthwise convolution: one filter per input channel.
62
+ self.depthwise = nn.Conv2d(
63
+ in_channels=in_channels,
64
+ out_channels=in_channels,
65
+ kernel_size=kernel_size,
66
+ stride=stride,
67
+ padding=padding,
68
+ dilation=dilation,
69
+ groups=in_channels,
70
+ bias=bias,
71
+ padding_mode=padding_mode
72
+ )
73
+ # Pointwise convolution: combine the features across channels.
74
+ self.pointwise = nn.Conv2d(
75
+ in_channels=in_channels,
76
+ out_channels=out_channels,
77
+ kernel_size=1,
78
+ stride=1,
79
+ padding=0,
80
+ dilation=1,
81
+ groups=1,
82
+ bias=bias,
83
+ padding_mode=padding_mode
84
+ )
85
+ self.apply(_init_weights)
86
+
87
+ def forward(self, x: Tensor) -> Tensor:
88
+ return self.pointwise(self.depthwise(x))
89
+
90
+
91
+ class SEBlock(nn.Module):
92
+ def __init__(self, channels: int, reduction: int = 16):
93
+ super().__init__()
94
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
95
+ self.fc = nn.Sequential(
96
+ nn.Linear(channels, channels // reduction, bias=False),
97
+ nn.ReLU(inplace=True),
98
+ nn.Linear(channels // reduction, channels, bias=False),
99
+ nn.Sigmoid()
100
+ )
101
+ self.apply(_init_weights)
102
+
103
+ def forward(self, x: Tensor) -> Tensor:
104
+ B, C, _, _ = x.shape
105
+ y = self.avg_pool(x).view(B, C)
106
+ y = self.fc(y).view(B, C, 1, 1)
107
+ return x * y
108
+
109
+
110
+ class BasicBlock(nn.Module):
111
+ def __init__(
112
+ self,
113
+ in_channels: int,
114
+ out_channels: int,
115
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
116
+ activation: nn.Module = nn.ReLU(inplace=True),
117
+ groups: int = 1,
118
+ ) -> None:
119
+ super().__init__()
120
+ assert isinstance(groups, int) and groups > 0, f"Expected groups to be a positive integer, but got {groups}"
121
+ assert in_channels % groups == 0, f"Expected in_channels to be divisible by groups, but got {in_channels} % {groups}"
122
+ assert out_channels % groups == 0, f"Expected out_channels to be divisible by groups, but got {out_channels} % {groups}"
123
+ self.grouped_conv = groups > 1
124
+ self.conv1 = conv3x3(
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ stride=1,
128
+ bias=not norm_layer,
129
+ groups=groups,
130
+ )
131
+ if self.grouped_conv:
132
+ self.conv1_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=not norm_layer)
133
+
134
+ self.norm1 = norm_layer(out_channels) if norm_layer else nn.Identity()
135
+ self.act1 = activation
136
+
137
+ self.conv2 = conv3x3(
138
+ in_channels=out_channels,
139
+ out_channels=out_channels,
140
+ stride=1,
141
+ bias=not norm_layer,
142
+ groups=groups,
143
+ )
144
+ if self.grouped_conv:
145
+ self.conv2_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=not norm_layer)
146
+
147
+ self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
148
+ self.act2 = activation
149
+
150
+ if in_channels != out_channels:
151
+ self.downsample = nn.Sequential(
152
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
153
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
154
+ )
155
+ else:
156
+ self.downsample = nn.Identity()
157
+
158
+ self.apply(_init_weights)
159
+
160
+ def forward(self, x: Tensor) -> Tensor:
161
+ identity = x
162
+
163
+ out = self.conv1(x)
164
+ out = self.conv1_1x1(out) if self.grouped_conv else out
165
+ out = self.norm1(out)
166
+ out = self.act1(out)
167
+
168
+ out = self.conv2(out)
169
+ out = self.conv2_1x1(out) if self.grouped_conv else out
170
+ out = self.norm2(out)
171
+
172
+ out += self.downsample(identity)
173
+ out = self.act2(out)
174
+
175
+ return out
176
+
177
+
178
+ class LightBasicBlock(nn.Module):
179
+ def __init__(
180
+ self,
181
+ in_channels: int,
182
+ out_channels: int,
183
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
184
+ activation: nn.Module = nn.ReLU(inplace=True),
185
+ ) -> None:
186
+ super().__init__()
187
+ self.conv1 = DepthSeparableConv2d(
188
+ in_channels=in_channels,
189
+ out_channels=out_channels,
190
+ kernel_size=3,
191
+ stride=1,
192
+ padding=1,
193
+ bias=not norm_layer,
194
+ )
195
+ self.norm1 = norm_layer(out_channels) if norm_layer else nn.Identity()
196
+ self.act1 = activation
197
+
198
+ self.conv2 = DepthSeparableConv2d(
199
+ in_channels=out_channels,
200
+ out_channels=out_channels,
201
+ kernel_size=3,
202
+ stride=1,
203
+ padding=1,
204
+ bias=not norm_layer,
205
+ )
206
+ self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
207
+ self.act2 = activation
208
+
209
+ if in_channels != out_channels:
210
+ self.downsample = nn.Sequential(
211
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
212
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
213
+ )
214
+ else:
215
+ self.downsample = nn.Identity()
216
+
217
+ self.apply(_init_weights)
218
+
219
+ def forward(self, x: Tensor) -> Tensor:
220
+ identity = x
221
+
222
+ out = self.conv1(x)
223
+ out = self.norm1(out)
224
+ out = self.act1(out)
225
+
226
+ out = self.conv2(out)
227
+ out = self.norm2(out)
228
+
229
+ out += self.downsample(identity)
230
+ out = self.act2(out)
231
+
232
+ return out
233
+
234
+
235
+ class Bottleneck(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels: int,
239
+ out_channels: int,
240
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
241
+ activation: nn.Module = nn.ReLU(inplace=True),
242
+ groups: int = 1,
243
+ base_width: int = 64,
244
+ expansion: float = 2.0,
245
+ ) -> None:
246
+ super().__init__()
247
+ assert isinstance(groups, int) and groups > 0, f"Expected groups to be a positive integer, but got {groups}"
248
+ assert expansion > 0, f"Expected expansion to be greater than 0, but got {expansion}"
249
+ assert base_width > 0, f"Expected base_width to be greater than 0, but got {base_width}"
250
+ bottleneck_channels = int(in_channels * (base_width / 64.0) * expansion)
251
+ assert bottleneck_channels % groups == 0, f"Expected bottleneck_channels to be divisible by groups, but got {bottleneck_channels} % {groups}"
252
+ self.grouped_conv = groups > 1
253
+ self.expansion, self.base_width = expansion, base_width
254
+
255
+ self.conv_in = conv1x1(in_channels, bottleneck_channels, stride=1, bias=not norm_layer)
256
+ self.norm_in = norm_layer(bottleneck_channels) if norm_layer else nn.Identity()
257
+ self.act_in = activation
258
+
259
+ self.se_in = SEBlock(bottleneck_channels) if bottleneck_channels > in_channels else nn.Identity()
260
+
261
+ self.conv_block_1 = nn.Sequential(
262
+ conv3x3(
263
+ in_channels=bottleneck_channels,
264
+ out_channels=bottleneck_channels,
265
+ stride=1,
266
+ groups=groups,
267
+ bias=not norm_layer
268
+ ),
269
+ conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer) if groups > 1 else nn.Identity(),
270
+ norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
271
+ activation,
272
+ )
273
+
274
+ self.conv_block_2 = nn.Sequential(
275
+ conv3x3(
276
+ in_channels=bottleneck_channels,
277
+ out_channels=bottleneck_channels,
278
+ stride=1,
279
+ groups=groups,
280
+ bias=not norm_layer
281
+ ),
282
+ conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer) if groups > 1 else nn.Identity(),
283
+ norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
284
+ activation,
285
+ )
286
+
287
+ self.conv_out = conv1x1(bottleneck_channels, out_channels, stride=1, bias=not norm_layer)
288
+ self.norm_out = norm_layer(out_channels) if norm_layer else nn.Identity()
289
+ self.act_out = activation
290
+ self.se_out = SEBlock(out_channels) if out_channels > bottleneck_channels else nn.Identity()
291
+
292
+ if in_channels != out_channels:
293
+ self.downsample = nn.Sequential(
294
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
295
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
296
+ )
297
+ else:
298
+ self.downsample = nn.Identity()
299
+
300
+ self.apply(_init_weights)
301
+
302
+ def forward(self, x: Tensor) -> Tensor:
303
+ identity = x
304
+
305
+ # expand
306
+ out = self.conv_in(x)
307
+ out = self.norm_in(out)
308
+ out = self.act_in(out)
309
+ out = self.se_in(out)
310
+
311
+ # conv
312
+ out = self.conv_block_1(out)
313
+ out = self.conv_block_2(out)
314
+
315
+ # reduce
316
+ out = self.conv_out(out)
317
+ out = self.norm_out(out)
318
+ out = self.se_out(out)
319
+
320
+ out += self.downsample(identity)
321
+ out = self.act_out(out)
322
+ return out
323
+
324
+
325
+ class ConvASPP(nn.Module):
326
+ def __init__(
327
+ self,
328
+ in_channels: int,
329
+ out_channels: int,
330
+ dilations: List[int] = [1, 2, 4],
331
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
332
+ activation: nn.Module = nn.ReLU(inplace=True),
333
+ groups: int = 1,
334
+ base_width: int = 64,
335
+ expansion: float = 2.0,
336
+ ) -> None:
337
+ super().__init__()
338
+ assert isinstance(groups, int) and groups > 0, f"Expected groups to be a positive integer, but got {groups}"
339
+ assert expansion > 0, f"Expected expansion to be greater than 0, but got {expansion}"
340
+ assert base_width > 0, f"Expected base_width to be greater than 0, but got {base_width}"
341
+ bottleneck_channels = int(in_channels * (base_width / 64.0) * expansion)
342
+ assert bottleneck_channels % groups == 0, f"Expected bottleneck_channels to be divisible by groups, but got {bottleneck_channels} % {groups}"
343
+ self.expansion, self.base_width = expansion, base_width
344
+
345
+ self.conv_in = conv1x1(in_channels, bottleneck_channels, stride=1, bias=not norm_layer)
346
+ self.norm_in = norm_layer(bottleneck_channels) if norm_layer else nn.Identity()
347
+ self.act_in = activation
348
+
349
+ conv_blocks = [nn.Sequential(
350
+ conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer),
351
+ norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
352
+ activation
353
+ )]
354
+
355
+ for dilation in dilations:
356
+ conv_blocks.append(nn.Sequential(
357
+ conv3x3(
358
+ in_channels=bottleneck_channels,
359
+ out_channels=bottleneck_channels,
360
+ stride=1,
361
+ groups=groups,
362
+ dilation=dilation,
363
+ bias=not norm_layer
364
+ ),
365
+ conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer) if groups > 1 else nn.Identity(),
366
+ norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
367
+ activation
368
+ ))
369
+
370
+ self.convs = nn.ModuleList(conv_blocks)
371
+
372
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
373
+ self.conv_avg = conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer)
374
+ self.norm_avg = norm_layer(bottleneck_channels) if norm_layer else nn.Identity()
375
+ self.act_avg = activation
376
+
377
+ self.se = SEBlock(bottleneck_channels * (len(dilations) + 2))
378
+
379
+ self.conv_out = conv1x1(bottleneck_channels * (len(dilations) + 2), out_channels, stride=1, bias=not norm_layer)
380
+ self.norm_out = norm_layer(out_channels) if norm_layer else nn.Identity()
381
+ self.act_out = activation
382
+
383
+ if in_channels != out_channels:
384
+ self.downsample = nn.Sequential(
385
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
386
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
387
+ )
388
+ else:
389
+ self.downsample = nn.Identity()
390
+
391
+ self.apply(_init_weights)
392
+
393
+ def forward(self, x: Tensor) -> Tensor:
394
+ height, width = x.shape[-2:]
395
+ identity = x
396
+
397
+ # expand
398
+ out = self.conv_in(x)
399
+ out = self.norm_in(out)
400
+ out = self.act_in(out)
401
+
402
+ outs = []
403
+ for conv in self.convs:
404
+ outs.append(conv(out))
405
+
406
+ avg = self.avgpool(out)
407
+ avg = self.conv_avg(avg)
408
+ avg = self.norm_avg(avg)
409
+ avg = self.act_avg(avg) # (B, C, 1, 1)
410
+ avg = avg.repeat(1, 1, height, width)
411
+
412
+ outs = torch.cat([*outs, avg], dim=1) # (B, C * (len(dilations) + 2), H, W)
413
+ outs = self.se(outs)
414
+
415
+ # reduce
416
+ outs = self.conv_out(outs)
417
+ outs = self.norm_out(outs)
418
+
419
+ outs += self.downsample(identity)
420
+ outs = self.act_out(outs)
421
+ return outs
422
+
423
+
424
+ class ViTBlock(nn.Module):
425
+ def __init__(
426
+ self,
427
+ embed_dim: int,
428
+ num_heads: int = 8,
429
+ dropout: float = 0.0,
430
+ mlp_ratio: float = 4.0,
431
+ ) -> None:
432
+ super().__init__()
433
+ assert embed_dim % num_heads == 0, f"Embedding dimension {embed_dim} should be divisible by number of heads {num_heads}"
434
+ self.embed_dim, self.num_heads = embed_dim, num_heads
435
+ self.dropout, self.mlp_ratio = dropout, mlp_ratio
436
+
437
+ self.norm1 = nn.LayerNorm(embed_dim)
438
+ self.attn = nn.MultiheadAttention(
439
+ embed_dim=embed_dim,
440
+ num_heads=num_heads,
441
+ dropout=dropout,
442
+ batch_first=True
443
+ )
444
+
445
+ self.norm2 = nn.LayerNorm(embed_dim)
446
+ self.mlp = nn.Sequential(
447
+ nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
448
+ nn.GELU(),
449
+ nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
450
+ nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
451
+ nn.Dropout(dropout) if dropout > 0 else nn.Identity()
452
+ )
453
+ self.apply(_init_weights)
454
+
455
+ def forward(self, x: Tensor) -> Tensor:
456
+ assert len(x.shape) == 3, f"Expected input to have shape (B, N, C), but got {x.shape}"
457
+ y = self.norm1(x)
+ attn_out, _ = self.attn(y, y, y, need_weights=False)  # nn.MultiheadAttention takes (query, key, value) and returns (output, weights)
+ x = x + attn_out
458
+ x = x + self.mlp(self.norm2(x))
459
+ return x
460
+
461
+
462
+ class Conv2dLayerNorm(nn.Sequential):
463
+ """
464
+ Layer normalization applied in a convolutional fashion.
465
+ """
466
+ def __init__(self, dim: int) -> None:
467
+ super().__init__(
468
+ Rearrange("B C H W -> B H W C"),
469
+ nn.LayerNorm(dim),
470
+ Rearrange("B H W C -> B C H W")
471
+ )
472
+ self.apply(_init_weights)
473
+
474
+
475
+ class CvTAttention(nn.Module):
476
+ def __init__(
477
+ self,
478
+ embed_dim: int,
479
+ num_heads: int = 8,
480
+ dropout: float = 0.0,
481
+ q_stride: int = 1, # controls downsampling rate
482
+ kv_stride: int = 1,
483
+ ) -> None:
484
+ super().__init__()
485
+ assert embed_dim % num_heads == 0, f"Embedding dimension {embed_dim} should be divisible by number of heads {num_heads}"
486
+ self.embed_dim, self.num_heads, self.dim_head = embed_dim, num_heads, embed_dim // num_heads
487
+ self.scale = self.dim_head ** -0.5
488
+ self.q_stride, self.kv_stride = q_stride, kv_stride
489
+
490
+ self.attend = nn.Softmax(dim=-1)
491
+ self.dropout = nn.Dropout(dropout)
492
+
493
+ self.to_q = DepthSeparableConv2d(
494
+ in_channels=embed_dim,
495
+ out_channels=embed_dim,
496
+ kernel_size=3,
497
+ stride=q_stride,
498
+ padding=1,
499
+ bias=False
500
+ )
501
+ self.to_k = DepthSeparableConv2d(
502
+ in_channels=embed_dim,
503
+ out_channels=embed_dim,
504
+ kernel_size=3,
505
+ stride=kv_stride,
506
+ padding=1,
507
+ bias=False
508
+ )
509
+ self.to_v = DepthSeparableConv2d(
510
+ in_channels=embed_dim,
511
+ out_channels=embed_dim,
512
+ kernel_size=3,
513
+ stride=kv_stride,
514
+ padding=1,
515
+ bias=False
516
+ )
517
+
518
+ self.to_out = nn.Sequential(
519
+ conv1x1(embed_dim, embed_dim, stride=1),
520
+ nn.Dropout(dropout) if dropout > 0 else nn.Identity()
521
+ )
522
+
523
+ self.apply(_init_weights)
524
+
525
+ def forward(self, x: Tensor) -> Tensor:
526
+ assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
527
+ assert x.shape[1] == self.embed_dim, f"Expected input to have embedding dimension {self.embed_dim}, but got {x.shape[1]}"
528
+
529
+ q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
530
+ B, _, H, W = q.shape
531
+ q, k, v = map(lambda t: rearrange(t, "B (num_heads head_dim) H W -> (B num_heads) (H W) head_dim", num_heads=self.num_heads), (q, k, v))
532
+ attn = (q @ k.transpose(-2, -1)) * self.scale
533
+ attn = self.attend(attn)
534
+ attn = self.dropout(attn)
535
+
536
+ out = attn @ v
537
+ out = rearrange(out, "(B num_heads) (H W) head_dim -> B (num_heads head_dim) H W", B=B, H=H, W=W, num_heads=self.num_heads)
538
+ out = self.to_out(out)
539
+
540
+ return out
541
+
542
+
543
+ class CvTBlock(nn.Module):
544
+ """
545
+ Implement convolutional vision transformer block.
546
+ """
547
+ def __init__(
548
+ self,
549
+ embed_dim: int,
550
+ num_heads: int = 8,
551
+ dropout: float = 0.0,
552
+ mlp_ratio: float = 4.0,
553
+ q_stride: int = 1,
554
+ kv_stride: int = 1,
555
+ ) -> None:
556
+ super().__init__()
557
+ assert embed_dim % num_heads == 0, f"Embedding dimension {embed_dim} should be divisible by number of heads {num_heads}."
558
+ self.embed_dim, self.num_heads = embed_dim, num_heads
559
+
560
+ self.norm1 = Conv2dLayerNorm(embed_dim)
561
+ self.attn = CvTAttention(embed_dim, num_heads, dropout, q_stride, kv_stride)
562
+
563
+ self.pool = nn.AvgPool2d(kernel_size=q_stride, stride=q_stride) if q_stride > 1 else nn.Identity()
564
+
565
+ self.norm2 = Conv2dLayerNorm(embed_dim)
566
+ self.mlp = nn.Sequential(
567
+ nn.Conv2d(embed_dim, int(embed_dim * mlp_ratio), kernel_size=1),
568
+ nn.GELU(),
569
+ nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
570
+ nn.Conv2d(int(embed_dim * mlp_ratio), embed_dim, kernel_size=1),
571
+ nn.Dropout(dropout) if dropout > 0 else nn.Identity()
572
+ )
573
+
574
+ def forward(self, x: Tensor) -> Tensor:
575
+ x = self.pool(x) + self.attn(self.norm1(x))
576
+ x = x + self.mlp(self.norm2(x))
577
+ return x
578
+
579
+
580
+ class ConvAdapter(nn.Module):
581
+ def __init__(
582
+ self,
583
+ in_channels: int,
584
+ bottleneck_channels: int = 16,
585
+ ) -> None:
586
+ super().__init__()
587
+ assert in_channels > 0, f"Expected input_channels to be greater than 0, but got {in_channels}"
588
+ assert bottleneck_channels > 0, f"Expected bottleneck_channels to be greater than 0, but got {bottleneck_channels}"
589
+
590
+ self.adapter = nn.Sequential(
591
+ nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1),
592
+ nn.GELU(),
593
+ nn.Conv2d(bottleneck_channels, in_channels, kernel_size=1),
594
+ )
595
+ nn.init.zeros_(self.adapter[2].weight)
596
+ nn.init.zeros_(self.adapter[2].bias)
597
+
598
+ def forward(self, x: Tensor) -> Tensor:
599
+ assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
600
+ return x + self.adapter(x)
601
+
602
+
603
+ class ViTAdapter(nn.Module):
604
+ def __init__(self, input_dim: int, bottleneck_dim: int) -> None:
605
+ super().__init__()
606
+ self.adapter = nn.Sequential(
607
+ nn.Linear(input_dim, bottleneck_dim),
608
+ nn.GELU(),  # GELU is the activation typically used in ViT
609
+ nn.Linear(bottleneck_dim, input_dim)
610
+ )
611
+ nn.init.zeros_(self.adapter[2].weight)
612
+ nn.init.zeros_(self.adapter[2].bias)
613
+
614
+ def forward(self, x: Tensor) -> Tensor:
615
+ assert len(x.shape) == 3, f"Expected input to have shape (B, N, C), but got {x.shape}"
616
+ return x + self.adapter(x)
617
+
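Because the last projection of both adapters above is zero-initialised, they start out as exact identity mappings; a short sketch (illustrative, not part of the commit):

import torch

vit_adapter = ViTAdapter(input_dim=768, bottleneck_dim=64)
tokens = torch.randn(2, 197, 768)
assert torch.allclose(vit_adapter(tokens), tokens)   # residual branch contributes zero at init

conv_adapter = ConvAdapter(in_channels=256, bottleneck_channels=16)
fmap = torch.randn(2, 256, 14, 14)
assert torch.allclose(conv_adapter(fmap), fmap)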
models/utils/carafe.py ADDED
@@ -0,0 +1,203 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def carafe_forward(
6
+ features: torch.Tensor,
7
+ masks: torch.Tensor,
8
+ kernel_size: int,
9
+ group_size: int,
10
+ scale_factor: int
11
+ ) -> torch.Tensor:
12
+ """
13
+ Pure-PyTorch implementation of the CARAFE upsampling operator.
14
+
15
+ Args:
16
+ features (Tensor): Input feature map of shape (N, C, H, W).
17
+ masks (Tensor): Reassembly kernel weights of shape
18
+ (N, kernel_size*kernel_size*group_size, H_out, W_out),
19
+ where H_out = H*scale_factor and W_out = W*scale_factor.
20
+ kernel_size (int): The spatial size of the reassembly kernel.
21
+ group_size (int): The group size to divide channels. Must divide C.
22
+ scale_factor (int): The upsampling factor.
23
+
24
+ Returns:
25
+ Tensor: Upsampled feature map of shape (N, C, H*scale_factor, W*scale_factor).
26
+ """
27
+ N, C, H, W = features.size()
28
+ out_H, out_W = H * scale_factor, W * scale_factor
29
+ num_channels = C // group_size # channels per group
30
+
31
+ # Reshape features to (N, group_size, num_channels, H, W)
32
+ features = features.view(N, group_size, num_channels, H, W)
33
+ # Merge batch and group dims for unfolding
34
+ features_reshaped = features.view(N * group_size, num_channels, H, W)
35
+ # Extract local patches; use padding so that output spatial dims match input
36
+ patches = F.unfold(features_reshaped, kernel_size=kernel_size,
37
+ padding=(kernel_size - 1) // 2)
38
+ # patches shape: (N*group_size, num_channels*kernel_size*kernel_size, H*W)
39
+ # Reshape to (N, group_size, num_channels, kernel_size*kernel_size, H, W)
40
+ patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H, W)
41
+ # Flatten spatial dimensions: now (N, group_size, num_channels, kernel_size*kernel_size, H*W)
42
+ patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H * W)
43
+
44
+ # For each output pixel location, determine the corresponding base input index.
45
+ # For an output coordinate (oh, ow), the corresponding input index is:
46
+ # h = oh // scale_factor, w = ow // scale_factor, linear index = h * W + w.
47
+ device = features.device
48
+ # Create coordinate indices for output
49
+ h_idx = torch.div(torch.arange(out_H, device=device), scale_factor, rounding_mode='floor') # (out_H,)
50
+ w_idx = torch.div(torch.arange(out_W, device=device), scale_factor, rounding_mode='floor') # (out_W,)
51
+ # Form a 2D grid of base indices (shape: out_H x out_W)
52
+ h_idx = h_idx.unsqueeze(1).expand(out_H, out_W) # (out_H, out_W)
53
+ w_idx = w_idx.unsqueeze(0).expand(out_H, out_W) # (out_H, out_W)
54
+ base_idx = (h_idx * W + w_idx).view(-1) # (out_H*out_W,)
55
+
56
+ # Expand base_idx so that it can index the last dimension of patches:
57
+ # Desired shape for gathering: (N, group_size, num_channels, kernel_size*kernel_size, out_H*out_W)
58
+ base_idx = base_idx.view(1, 1, 1, 1, -1).expand(N, group_size, num_channels, kernel_size * kernel_size, -1)
59
+ # Gather patches corresponding to each output location
60
+ gathered_patches = torch.gather(patches, -1, base_idx)
61
+ # Reshape gathered patches to (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
62
+ gathered_patches = gathered_patches.view(N, group_size, num_channels, kernel_size * kernel_size, out_H, out_W)
63
+
64
+ # Reshape masks to separate groups.
65
+ # Expected mask shape: (N, kernel_size*kernel_size*group_size, out_H, out_W)
66
+ # Reshape to: (N, group_size, kernel_size*kernel_size, out_H, out_W)
67
+ masks = masks.view(N, group_size, kernel_size * kernel_size, out_H, out_W)
68
+ # For multiplication, add a channel dimension so that masks shape becomes
69
+ # (N, group_size, 1, kernel_size*kernel_size, out_H, out_W)
70
+ masks = masks.unsqueeze(2)
71
+ # Expand masks to match gathered_patches: (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
72
+ masks = masks.expand(-1, -1, num_channels, -1, -1, -1)
73
+
74
+ # Multiply patches with masks and sum over the kernel dimension.
75
+ # This yields the reassembled features for each output location.
76
+ out = (gathered_patches * masks).sum(dim=3) # shape: (N, group_size, num_channels, out_H, out_W)
77
+ # Reshape back to (N, C, out_H, out_W)
78
+ out = out.view(N, C, out_H, out_W)
79
+ return out
80
+
81
+
82
+ class CARAFE(nn.Module):
83
+ """
84
+ CARAFE: Content-Aware ReAssembly of Features
85
+
86
+ This PyTorch module implements the CARAFE upsampling operator in pure Python.
87
+ Given an input feature map and its corresponding reassembly masks, the module
88
+ reassembles features from local patches to produce a higher-resolution output.
89
+
90
+ Args:
91
+ kernel_size (int): Reassembly kernel size.
92
+ group_size (int): Group size for channel grouping (must divide number of channels).
93
+ scale_factor (int): Upsample ratio.
94
+ """
95
+ def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
96
+ super(CARAFE, self).__init__()
97
+ self.kernel_size = kernel_size
98
+ self.group_size = group_size
99
+ self.scale_factor = scale_factor
100
+
101
+ def forward(self, features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
102
+ return carafe_forward(features, masks, self.kernel_size, self.group_size, self.scale_factor)
103
+
104
+
105
+ class CARAFEPack(nn.Module):
106
+ """
107
+ A unified package of the CARAFE upsampler that contains:
108
+ 1) A channel compressor.
109
+ 2) A content encoder that predicts reassembly masks.
110
+ 3) The CARAFE operator.
111
+
112
+ This is modeled after the official CARAFE package.
113
+
114
+ Args:
115
+ channels (int): Number of input feature channels.
116
+ scale_factor (int): Upsample ratio.
117
+ up_kernel (int): Kernel size for the CARAFE operator.
118
+ up_group (int): Group size for the CARAFE operator.
119
+ encoder_kernel (int): Kernel size of the content encoder.
120
+ encoder_dilation (int): Dilation rate for the content encoder.
121
+ compressed_channels (int): Output channels for the channel compressor.
122
+ """
123
+ def __init__(
124
+ self,
125
+ channels: int,
126
+ scale_factor: int,
127
+ up_kernel: int = 5,
128
+ up_group: int = 1,
129
+ encoder_kernel: int = 3,
130
+ encoder_dilation: int = 1,
131
+ compressed_channels: int = 64
132
+ ):
133
+ super(CARAFEPack, self).__init__()
134
+ self.channels = channels
135
+ self.scale_factor = scale_factor
136
+ self.up_kernel = up_kernel
137
+ self.up_group = up_group
138
+ self.encoder_kernel = encoder_kernel
139
+ self.encoder_dilation = encoder_dilation
140
+ self.compressed_channels = compressed_channels
141
+
142
+ # Compress input channels.
143
+ self.channel_compressor = nn.Conv2d(channels, compressed_channels, kernel_size=1)
144
+ # Predict reassembly masks.
145
+ self.content_encoder = nn.Conv2d(
146
+ compressed_channels,
147
+ up_kernel * up_kernel * up_group * scale_factor * scale_factor,
148
+ kernel_size=encoder_kernel,
149
+ padding=int((encoder_kernel - 1) * encoder_dilation / 2),
150
+ dilation=encoder_dilation
151
+ )
152
+ # Initialize weights (using Xavier for conv layers).
153
+ nn.init.xavier_uniform_(self.channel_compressor.weight)
154
+ nn.init.xavier_uniform_(self.content_encoder.weight)
155
+ if self.channel_compressor.bias is not None:
156
+ nn.init.constant_(self.channel_compressor.bias, 0)
157
+ if self.content_encoder.bias is not None:
158
+ nn.init.constant_(self.content_encoder.bias, 0)
159
+
160
+ def kernel_normalizer(self, mask: torch.Tensor) -> torch.Tensor:
161
+ """
162
+ Normalize and reshape the mask.
163
+ Applies pixel shuffle to upsample the predicted kernel weights and then
164
+ applies softmax normalization across the kernel dimension.
165
+
166
+ Args:
167
+ mask (Tensor): Predicted mask of shape (N, out_channels, H, W).
168
+
169
+ Returns:
170
+ Tensor: Normalized mask of shape (N, up_group * up_kernel^2, H*scale, W*scale).
171
+ """
172
+ # Pixel shuffle to rearrange and upsample the mask.
173
+ mask = F.pixel_shuffle(mask, self.scale_factor)
174
+ N, mask_c, H, W = mask.size()
175
+ # Determine the number of channels per kernel
176
+ mask_channel = mask_c // (self.up_kernel ** 2)
177
+ mask = mask.view(N, mask_channel, self.up_kernel ** 2, H, W)
178
+ mask = F.softmax(mask, dim=2)
179
+ mask = mask.view(N, mask_channel * self.up_kernel ** 2, H, W).contiguous()
180
+ return mask
181
+
182
+ def feature_reassemble(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
183
+ return carafe_forward(x, mask, self.up_kernel, self.up_group, self.scale_factor)
184
+
185
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
186
+ compressed_x = self.channel_compressor(x)
187
+ mask = self.content_encoder(compressed_x)
188
+ mask = self.kernel_normalizer(mask)
189
+ out = self.feature_reassemble(x, mask)
190
+ return out
191
+
192
+
193
+ # === Example Usage ===
194
+ if __name__ == '__main__':
195
+ # Create dummy input: batch size 2, 64 channels, 32x32 spatial resolution.
196
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # fall back to CPU when no GPU is available
+ x = torch.randn(2, 64, 32, 32, device=device)
197
+ # Define CARAFEPack with upsample ratio 2.
198
+ # For example, use kernel size 5, group size 1.
199
+ upsampler = CARAFEPack(channels=64, scale_factor=2, up_kernel=5, up_group=1).to(device)
200
+ # Get upsampled feature map.
201
+ out = upsampler(x)
202
+ print("Input shape: ", x.shape)
203
+ print("Output shape:", out.shape) # Expected shape: (2, 64, 64, 64)
models/utils/downsample.py ADDED
@@ -0,0 +1,239 @@
1
+ from torch import nn, Tensor
2
+
3
+ from typing import Union
4
+
5
+ from .blocks import DepthSeparableConv2d, conv1x1, conv3x3
6
+ from .utils import _init_weights
7
+
8
+
9
+ class ConvDownsample(nn.Module):
10
+ def __init__(
11
+ self,
12
+ in_channels: int,
13
+ out_channels: int,
14
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
15
+ activation: nn.Module = nn.ReLU(inplace=True),
16
+ groups: int = 1,
17
+ ) -> None:
18
+ super().__init__()
19
+ assert isinstance(groups, int) and groups > 0, f"Number of groups should be an integer greater than 0, but got {groups}."
20
+ assert in_channels % groups == 0, f"Number of input channels {in_channels} should be divisible by number of groups {groups}."
21
+ assert out_channels % groups == 0, f"Number of output channels {out_channels} should be divisible by number of groups {groups}."
22
+ self.grouped_conv = groups > 1
23
+
24
+ # conv1 is used for downsampling
25
+ # self.conv1 = nn.Conv2d(
26
+ # in_channels=in_channels,
27
+ # out_channels=in_channels,
28
+ # kernel_size=2,
29
+ # stride=2,
30
+ # padding=0,
31
+ # bias=not norm_layer,
32
+ # groups=groups,
33
+ # )
34
+ # if self.grouped_conv:
35
+ # self.conv1_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=not norm_layer)
36
+ self.conv1 = nn.AvgPool2d(kernel_size=2, stride=2) # downsample by 2
37
+ if self.grouped_conv:
38
+ self.conv1_1x1 = nn.Identity()
39
+
40
+ self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
41
+ self.act1 = activation
42
+
43
+ self.conv2 = conv3x3(
44
+ in_channels=in_channels,
45
+ out_channels=in_channels,
46
+ stride=1,
47
+ groups=groups,
48
+ bias=not norm_layer,
49
+ )
50
+ if self.grouped_conv:
51
+ self.conv2_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=not norm_layer)
52
+
53
+ self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
54
+ self.act2 = activation
55
+
56
+ self.conv3 = conv3x3(
57
+ in_channels=in_channels,
58
+ out_channels=out_channels,
59
+ stride=1,
60
+ groups=groups,
61
+ bias=not norm_layer,
62
+ )
63
+ if self.grouped_conv:
64
+ self.conv3_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=not norm_layer)
65
+
66
+ self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
67
+ self.act3 = activation
68
+
69
+ self.downsample = nn.Sequential(
70
+ nn.AvgPool2d(kernel_size=2, stride=2), # make sure the spatial sizes match
71
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
72
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
73
+ )
74
+
75
+ self.apply(_init_weights)
76
+
77
+ def forward(self, x: Tensor) -> Tensor:
78
+ identity = x
79
+
80
+ # downsample
81
+ out = self.conv1(x)
82
+ out = self.conv1_1x1(out) if self.grouped_conv else out
83
+ out = self.norm1(out)
84
+ out = self.act1(out)
85
+
86
+ out = self.conv2(out)
87
+ out = self.conv2_1x1(out) if self.grouped_conv else out
88
+ out = self.norm2(out)
89
+ out = self.act2(out)
90
+
91
+ out = self.conv3(out)
92
+ out = self.conv3_1x1(out) if self.grouped_conv else out
93
+ out = self.norm3(out)
94
+
95
+ # shortcut
96
+ out += self.downsample(identity)
97
+ out = self.act3(out)
98
+ return out
99
+
100
+
101
+ class LightConvDownsample(nn.Module):
102
+ def __init__(
103
+ self,
104
+ in_channels: int,
105
+ out_channels: int,
106
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
107
+ activation: nn.Module = nn.ReLU(inplace=True),
108
+ ) -> None:
109
+ super().__init__()
110
+ self.conv1 = DepthSeparableConv2d(
111
+ in_channels=in_channels,
112
+ out_channels=in_channels,
113
+ kernel_size=2,
114
+ stride=2,
115
+ padding=0,
116
+ bias=not norm_layer,
117
+ )
118
+ self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
119
+ self.act1 = activation
120
+
121
+ self.conv2 = DepthSeparableConv2d(
122
+ in_channels=in_channels,
123
+ out_channels=out_channels,
124
+ kernel_size=3,
125
+ stride=1,
126
+ padding=1,
127
+ bias=not norm_layer,
128
+ )
129
+ self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
130
+ self.act2 = activation
131
+
132
+ self.conv3 = DepthSeparableConv2d(
133
+ in_channels=out_channels,
134
+ out_channels=out_channels,
135
+ kernel_size=3,
136
+ stride=1,
137
+ padding=1,
138
+ bias=not norm_layer,
139
+ )
140
+ self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
141
+ self.act3 = activation
142
+
143
+ self.downsample = nn.Sequential(
144
+ nn.AvgPool2d(kernel_size=2, stride=2), # make sure the spatial sizes match
145
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
146
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
147
+ )
148
+
149
+ self.apply(_init_weights)
150
+
151
+ def forward(self, x: Tensor) -> Tensor:
152
+ identity = x
153
+
154
+ # downsample
155
+ out = self.conv1(x)
156
+ out = self.norm1(out)
157
+ out = self.act1(out)
158
+
159
+ # refine 1
160
+ out = self.conv2(out)
161
+ out = self.norm2(out)
162
+ out = self.act2(out)
163
+
164
+ # refine 2
165
+ out = self.conv3(out)
166
+ out = self.norm3(out)
167
+
168
+ # shortcut
169
+ out += self.downsample(identity)
170
+ out = self.act3(out)
171
+ return out
172
+
173
+
174
+ class LighterConvDownsample(nn.Module):
175
+ def __init__(
176
+ self,
177
+ in_channels: int,
178
+ out_channels: int,
179
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
180
+ activation: nn.Module = nn.ReLU(inplace=True),
181
+ ) -> None:
182
+ super().__init__()
183
+ self.conv1 = DepthSeparableConv2d(
184
+ in_channels=in_channels,
185
+ out_channels=in_channels,
186
+ kernel_size=2,
187
+ stride=2,
188
+ padding=0,
189
+ bias=not norm_layer,
190
+ )
191
+ self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
192
+ self.act1 = activation
193
+
194
+ self.conv2 = conv3x3(
195
+ in_channels=in_channels,
196
+ out_channels=in_channels,
197
+ stride=1,
198
+ groups=in_channels,
199
+ bias=not norm_layer,
200
+ )
201
+ self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
202
+ self.act2 = activation
203
+
204
+ self.conv3 = conv1x1(
205
+ in_channels=in_channels,
206
+ out_channels=out_channels,
207
+ stride=1,
208
+ bias=not norm_layer,
209
+ )
210
+ self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
211
+ self.act3 = activation
212
+
213
+ self.downsample = nn.Sequential(
214
+ nn.AvgPool2d(kernel_size=2, stride=2), # make sure the spatial sizes match
215
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
216
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
217
+ )
218
+ self.apply(_init_weights)
+
219
+ def forward(self, x: Tensor) -> Tensor:
220
+ identity = x
221
+
222
+ # downsample
223
+ out = self.conv1(x)
224
+ out = self.norm1(out)
225
+ out = self.act1(out)
226
+
227
+ # refine, depthwise conv
228
+ out = self.conv2(out)
229
+ out = self.norm2(out)
230
+ out = self.act2(out)
231
+
232
+ # refine, pointwise conv
233
+ out = self.conv3(out)
234
+ out = self.norm3(out)
235
+
236
+ # shortcut
237
+ out += self.downsample(identity)
238
+ out = self.act3(out)
239
+ return out
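A hedged shape check for the three downsample blocks above, assuming the conv1x1/conv3x3/DepthSeparableConv2d helpers imported from .blocks behave like their standard-convolution counterparts; each block should halve the spatial resolution while mapping 64 to 128 channels.

import torch

x = torch.randn(2, 64, 32, 32)
for block_cls in (ConvDownsample, LightConvDownsample, LighterConvDownsample):
    block = block_cls(in_channels=64, out_channels=128)
    print(block_cls.__name__, block(x).shape)   # expected: torch.Size([2, 128, 16, 16])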
models/utils/multi_scale.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from typing import List
4
+ from einops import rearrange
5
+
6
+ from .blocks import conv3x3, conv1x1, Conv2dLayerNorm, _init_weights
7
+
8
+
9
+ class MultiScale(nn.Module):
10
+ def __init__(
11
+ self,
12
+ channels: int,
13
+ scales: List[int],
14
+ heads: int = 8,
15
+ groups: int = 1,
16
+ mlp_ratio: float = 4.0,
17
+ ) -> None:
18
+ super().__init__()
19
+ assert channels > 0, "channels should be a positive integer"
20
+ assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all([scale > 0 for scale in scales]), "scales should be a list or tuple of positive integers"
21
+ assert heads > 0 and channels % heads == 0, "heads should be a positive integer and channels should be divisible by heads"
22
+ assert groups > 0 and channels % groups == 0, "groups should be a positive integer and channels should be divisible by groups"
23
+ scales = sorted(scales)
24
+ self.scales = scales
25
+ self.num_scales = len(scales) + 1 # +1 for the original feature map
26
+ self.heads = heads
27
+ self.groups = groups
28
+
29
+ # modules that generate multi-scale feature maps
30
+ self.scale_0 = nn.Sequential(
31
+ conv1x1(channels, channels, stride=1, bias=False),
32
+ Conv2dLayerNorm(channels),
33
+ nn.GELU(),
34
+ )
35
+ for scale in scales:
36
+ setattr(self, f"conv_{scale}", nn.Sequential(
37
+ conv3x3(
38
+ in_channels=channels,
39
+ out_channels=channels,
40
+ stride=1,
41
+ groups=groups,
42
+ dilation=scale,
43
+ bias=False,
44
+ ),
45
+ conv1x1(channels, channels, stride=1, bias=False) if groups > 1 else nn.Identity(),
46
+ Conv2dLayerNorm(channels),
47
+ nn.GELU(),
48
+ ))
49
+
50
+ # modules that fuse multi-scale feature maps
51
+ self.norm_attn = Conv2dLayerNorm(channels)
52
+ self.pos_embed = nn.Parameter(torch.randn(1, self.num_scales + 1, channels, 1, 1) / channels ** 0.5)
53
+ self.to_q = conv1x1(channels, channels, stride=1, bias=False)
54
+ self.to_k = conv1x1(channels, channels, stride=1, bias=False)
55
+ self.to_v = conv1x1(channels, channels, stride=1, bias=False)
56
+
57
+ self.scale = (channels // heads) ** -0.5
58
+
59
+ self.attend = nn.Softmax(dim=-1)
60
+
61
+ self.to_out = conv1x1(channels, channels, stride=1)
62
+
63
+ # modules that refine multi-scale feature maps
64
+ self.norm_mlp = Conv2dLayerNorm(channels)
65
+ self.mlp = nn.Sequential(
66
+ conv1x1(channels, int(channels * mlp_ratio), stride=1),
67
+ nn.GELU(),
68
+ conv1x1(int(channels * mlp_ratio), channels, stride=1),
69
+ )
70
+
71
+ self.apply(_init_weights)
72
+
73
+ def _forward_attn(self, x: Tensor) -> Tensor:
74
+ assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
75
+ x = [self.scale_0(x)] + [getattr(self, f"conv_{scale}")(x) for scale in self.scales]
76
+
77
+ x = torch.stack(x, dim=1) # (B, S, C, H, W)
78
+ x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1) # (B, S+1, C, H, W)
79
+ x = x + self.pos_embed # (B, S+1, C, H, W)
80
+
81
+ x = rearrange(x, "B S C H W -> (B S) C H W") # (B*(S+1), C, H, W)
82
+ x = self.norm_attn(x) # (B*(S+1), C, H, W)
83
+ x = rearrange(x, "(B S) C H W -> B S C H W", S=self.num_scales + 1) # (B, S+1, C, H, W)
84
+
85
+ q = self.to_q(x[:, 0]) # (B, C, H, W)
86
+ k = self.to_k(rearrange(x, "B S C H W -> (B S) C H W"))
87
+ v = self.to_v(rearrange(x, "B S C H W -> (B S) C H W"))
88
+
89
+ q = rearrange(q, "B (h d) H W -> B h H W 1 d", h=self.heads)
90
+ k = rearrange(k, "(B S) (h d) H W -> B h H W S d", S=self.num_scales + 1, h=self.heads)
91
+ v = rearrange(v, "(B S) (h d) H W -> B h H W S d", S=self.num_scales + 1, h=self.heads)
92
+
93
+ attn = q @ k.transpose(-2, -1) * self.scale # (B, h, H, W, 1, S+1)
94
+ attn = self.attend(attn) # (B, h, H, W, 1, S+1)
95
+ out = attn @ v # (B, h, H, W, 1, d)
96
+
97
+ out = rearrange(out, "B h H W 1 d -> B (h d) H W") # (B, C, H, W)
98
+
99
+ out = self.to_out(out) # (B, C, H, W)
100
+ return out
101
+
102
+ def _forward_mlp(self, x: Tensor) -> Tensor:
103
+ assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
104
+ x = self.norm_mlp(x)
105
+ x = self.mlp(x)
106
+ return x
107
+
108
+ def forward(self, x: Tensor) -> Tensor:
109
+ x = x + self._forward_attn(x)
110
+ x = x + self._forward_mlp(x)
111
+ return x
112
+
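A hedged sanity check for MultiScale, assuming Conv2dLayerNorm and the conv helpers from .blocks are importable and einops is installed; the block fuses the dilated branches with cross-scale attention and is shape-preserving.

import torch

msa = MultiScale(channels=64, scales=[2, 3], heads=8, groups=1)
x = torch.randn(2, 64, 16, 16)
print(msa(x).shape)   # expected: torch.Size([2, 64, 16, 16])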
models/utils/refine.py ADDED
@@ -0,0 +1,103 @@
1
+ from torch import nn, Tensor
2
+ from typing import Union
3
+
4
+ from .utils import _init_weights
5
+ from .blocks import BasicBlock, LightBasicBlock, conv1x1, conv3x3
6
+
7
+
8
+ class ConvRefine(nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_channels: int,
12
+ out_channels: int,
13
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
14
+ activation: nn.Module = nn.ReLU(inplace=True),
15
+ groups: int = 1,
16
+ ) -> None:
17
+ super().__init__()
18
+ self.refine = BasicBlock(
19
+ in_channels=in_channels,
20
+ out_channels=out_channels,
21
+ norm_layer=norm_layer,
22
+ activation=activation,
23
+ groups=groups,
24
+ )
25
+ self.apply(_init_weights)
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return self.refine(x)
29
+
30
+
31
+ class LightConvRefine(nn.Module):
32
+ def __init__(
33
+ self,
34
+ in_channels: int,
35
+ out_channels: int,
36
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
37
+ activation: nn.Module = nn.ReLU(inplace=True),
38
+ ) -> None:
39
+ super().__init__()
40
+ self.refine = LightBasicBlock(
41
+ in_channels=in_channels,
42
+ out_channels=out_channels,
43
+ norm_layer=norm_layer,
44
+ activation=activation,
45
+ )
46
+ self.apply(_init_weights)
47
+
48
+ def forward(self, x: Tensor) -> Tensor:
49
+ return self.refine(x)
50
+
51
+
52
+ class LighterConvRefine(nn.Module):
53
+ def __init__(
54
+ self,
55
+ in_channels: int,
56
+ out_channels: int,
57
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
58
+ activation: nn.Module = nn.ReLU(inplace=True),
59
+ ) -> None:
60
+ super().__init__()
61
+ # depthwise separable convolution
62
+ self.conv1 = conv3x3(
63
+ in_channels=in_channels,
64
+ out_channels=in_channels,
65
+ stride=1,
66
+ groups=in_channels,
67
+ bias=not norm_layer,
68
+ )
69
+ self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
70
+ self.act1 = activation
71
+
72
+ self.conv2 = conv1x1(
73
+ in_channels=in_channels,
74
+ out_channels=out_channels,
75
+ stride=1,
76
+ bias=not norm_layer,
77
+ )
78
+ self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
79
+ self.act2 = activation
80
+
81
+ if in_channels != out_channels:
82
+ self.downsample = nn.Sequential(
83
+ conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
84
+ norm_layer(out_channels) if norm_layer else nn.Identity(),
85
+ )
86
+ else:
87
+ self.downsample = nn.Identity()
88
+
89
+ self.apply(_init_weights)
90
+
91
+ def forward(self, x: Tensor) -> Tensor:
92
+ identity = x
93
+
94
+ out = self.conv1(x)
95
+ out = self.norm1(out)
96
+ out = self.act1(out)
97
+
98
+ out = self.conv2(out)
99
+ out = self.norm2(out)
100
+
101
+ out += self.downsample(identity)
102
+ out = self.act2(out)
103
+ return out
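A hedged shape check for the refine blocks, assuming BasicBlock and LightBasicBlock from .blocks follow the usual stride-1 residual-block behaviour; spatial size is preserved while channels are remapped.

import torch

x = torch.randn(2, 64, 32, 32)
for cls in (ConvRefine, LightConvRefine, LighterConvRefine):
    block = cls(in_channels=64, out_channels=96)
    print(cls.__name__, block(x).shape)   # expected: torch.Size([2, 96, 32, 32])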
models/utils/upsample.py ADDED
@@ -0,0 +1,118 @@
1
+ from torch import nn, Tensor
2
+ from torch.nn import functional as F
3
+
4
+ from typing import Union
5
+ from functools import partial
6
+
7
+ from .utils import _init_weights
8
+ from .refine import ConvRefine, LightConvRefine, LighterConvRefine
9
+
10
+
11
+ class ConvUpsample(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_channels: int,
15
+ out_channels: int,
16
+ scale_factor: int = 2,
17
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
18
+ activation: nn.Module = nn.ReLU(inplace=True),
19
+ groups: int = 1,
20
+ ) -> None:
21
+ super().__init__()
22
+ assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
23
+ self.scale_factor = scale_factor
24
+ self.upsample = partial(
25
+ F.interpolate,
26
+ scale_factor=scale_factor,
27
+ mode="bilinear",
28
+ align_corners=False,
29
+ recompute_scale_factor=False,
30
+ antialias=False,
31
+ ) if scale_factor > 1 else nn.Identity()
32
+
33
+ self.refine = ConvRefine(
34
+ in_channels=in_channels,
35
+ out_channels=out_channels,
36
+ norm_layer=norm_layer,
37
+ activation=activation,
38
+ groups=groups,
39
+ )
40
+
41
+ self.apply(_init_weights)
42
+
43
+ def forward(self, x: Tensor) -> Tensor:
44
+ x = self.upsample(x)
45
+ x = self.refine(x)
46
+ return x
47
+
48
+
49
+ class LightConvUpsample(nn.Module):
50
+ def __init__(
51
+ self,
52
+ in_channels: int,
53
+ out_channels: int,
54
+ scale_factor: int = 2,
55
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
56
+ activation: nn.Module = nn.ReLU(inplace=True),
57
+ ) -> None:
58
+ super().__init__()
59
+ assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
60
+ self.scale_factor = scale_factor
61
+ self.upsample = partial(
62
+ F.interpolate,
63
+ scale_factor=scale_factor,
64
+ mode="bilinear",
65
+ align_corners=False,
66
+ recompute_scale_factor=False,
67
+ antialias=False,
68
+ ) if scale_factor > 1 else nn.Identity()
69
+
70
+ self.refine = LightConvRefine(
71
+ in_channels=in_channels,
72
+ out_channels=out_channels,
73
+ norm_layer=norm_layer,
74
+ activation=activation,
75
+ )
76
+
77
+ self.apply(_init_weights)
78
+
79
+ def forward(self, x: Tensor) -> Tensor:
80
+ x = self.upsample(x)
81
+ x = self.refine(x)
82
+ return x
83
+
84
+
85
+ class LighterConvUpsample(nn.Module):
86
+ def __init__(
87
+ self,
88
+ in_channels: int,
89
+ out_channels: int,
90
+ scale_factor: int = 2,
91
+ norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
92
+ activation: nn.Module = nn.ReLU(inplace=True),
93
+ ) -> None:
94
+ super().__init__()
95
+ assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
96
+ self.scale_factor = scale_factor
97
+ self.upsample = partial(
98
+ F.interpolate,
99
+ scale_factor=scale_factor,
100
+ mode="bilinear",
101
+ align_corners=False,
102
+ recompute_scale_factor=False,
103
+ antialias=False,
104
+ ) if scale_factor > 1 else nn.Identity()
105
+
106
+ self.refine = LighterConvRefine(
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ norm_layer=norm_layer,
110
+ activation=activation,
111
+ )
112
+
113
+ self.apply(_init_weights)
114
+
115
+ def forward(self, x: Tensor) -> Tensor:
116
+ x = self.upsample(x)
117
+ x = self.refine(x)
118
+ return x
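A hedged shape check for the upsample blocks: bilinear interpolation by scale_factor followed by the matching refine block (same .blocks import assumptions as above).

import torch

x = torch.randn(2, 128, 16, 16)
up = ConvUpsample(in_channels=128, out_channels=64, scale_factor=2)
print(up(x).shape)   # expected: torch.Size([2, 64, 32, 32])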
models/utils/utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from typing import Tuple, Any, Optional, Union
5
+ from types import FunctionType
6
+ from itertools import repeat
7
+ from collections.abc import Iterable
8
+
9
+
10
+ def _log_api_usage_once(obj: Any) -> None:
11
+
12
+ """
13
+ Logs API usage (module and name) within an organization.
14
+ In a large ecosystem, it's often useful to track the PyTorch and
15
+ TorchVision APIs usage. This API provides the similar functionality to the
16
+ logging module in the Python stdlib. It can be used for debugging purpose
17
+ to log which methods are used and by default it is inactive, unless the user
18
+ manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
19
+ Please note it is triggered only once for the same API call within a process.
20
+ It does not collect any data from open-source users since it is no-op by default.
21
+ For more information, please refer to
22
+ * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
23
+ * Logging policy: https://github.com/pytorch/vision/issues/5052;
24
+
25
+ Args:
26
+ obj (class instance or method): an object to extract info from.
27
+ """
28
+ module = obj.__module__
29
+ if not module.startswith("torchvision"):
30
+ module = f"torchvision.internal.{module}"
31
+ name = obj.__class__.__name__
32
+ if isinstance(obj, FunctionType):
33
+ name = obj.__name__
34
+ torch._C._log_api_usage_once(f"{module}.{name}")
35
+
36
+
37
+ def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
38
+ """
39
+ Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
40
+ Otherwise, we will make a tuple of length n, all with value of x.
41
+ reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
42
+
43
+ Args:
44
+ x (Any): input value
45
+ n (int): length of the resulting tuple
46
+ """
47
+ if isinstance(x, Iterable):
48
+ return tuple(x)
49
+ return tuple(repeat(x, n))
50
+
51
+
52
+ def _init_weights(model: nn.Module) -> None:
53
+ for m in model.modules():
54
+ if isinstance(m, nn.Conv2d):
55
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
56
+ if m.bias is not None:
57
+ nn.init.constant_(m.bias, 0.)
58
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
59
+ nn.init.constant_(m.weight, 1.)
60
+ if m.bias is not None:
61
+ nn.init.constant_(m.bias, 0.)
62
+ elif isinstance(m, nn.Linear):
63
+ nn.init.normal_(m.weight, std=0.01)
64
+ if m.bias is not None:
65
+ nn.init.constant_(m.bias, 0.)
66
+
67
+
68
+ def interpolate_pos_embed(pos_embed: Tensor, size: Optional[Union[int, Tuple[int, int]]] = None, scale_factor: Optional[float] = None) -> Tensor:
69
+ assert len(pos_embed.shape) == 3, f"Positional embedding should be 3D tensor (C, H, W), but got {pos_embed.shape}."
70
+ return F.interpolate(
71
+ pos_embed.unsqueeze(0),
72
+ size=size,
73
+ scale_factor=scale_factor,
74
+ mode="bicubic",
75
+ align_corners=False,
76
+ antialias=True,
77
+ ).squeeze(0)
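A hedged example for interpolate_pos_embed; the 768x14x14 grid is an illustrative ViT-B/16-style positional embedding, not a checkpoint shipped with this repository.

import torch

pos = torch.randn(768, 14, 14)                       # (C, H, W) positional grid
resized = interpolate_pos_embed(pos, size=(24, 24))  # bicubic, antialiased resize
print(resized.shape)                                 # torch.Size([768, 24, 24])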