Spaces: Running on Zero
2025-07-31 18:59 🐣

- .gitignore +163 -0
- README.md +5 -5
- app.py +459 -0
- models/__init__.py +155 -0
- models/clip_ebc/__init__.py +7 -0
- models/clip_ebc/convnext.py +199 -0
- models/clip_ebc/mobileclip.py +197 -0
- models/clip_ebc/model.py +272 -0
- models/clip_ebc/resnet.py +236 -0
- models/clip_ebc/utils.py +137 -0
- models/clip_ebc/vit.py +372 -0
- models/ebc/__init__.py +3 -0
- models/ebc/cannet.py +105 -0
- models/ebc/csrnet.py +104 -0
- models/ebc/hrnet.py +195 -0
- models/ebc/model.py +199 -0
- models/ebc/timm_models.py +318 -0
- models/ebc/utils.py +37 -0
- models/ebc/vgg.py +255 -0
- models/ebc/vit.py +323 -0
- models/utils/__init__.py +56 -0
- models/utils/blocks.py +617 -0
- models/utils/carafe.py +203 -0
- models/utils/downsample.py +239 -0
- models/utils/multi_scale.py +112 -0
- models/utils/refine.py +103 -0
- models/utils/upsample.py +118 -0
- models/utils/utils.py +77 -0
.gitignore
ADDED
@@ -0,0 +1,163 @@
# MacOS
**/.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
-title: ZIP
-emoji:
-colorFrom:
-colorTo:
+title: ZIP
+emoji: 🔢
+colorFrom: indigo
+colorTo: pink
 sdk: gradio
 sdk_version: 5.39.0
 app_file: app.py
-pinned:
+pinned: true
 license: mit
 short_description: The crowd counting model ZIP-B
 ---
app.py
ADDED
@@ -0,0 +1,459 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from torch import Tensor
import spaces

import numpy as np
from PIL import Image
import gradio as gr
from matplotlib import cm
from huggingface_hub import hf_hub_download
from warnings import warn

from models import get_model


mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
EPS = 1e-8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loaded_model = None  # populated lazily by load_model()

pretrained_datasets = {
    "ZIP-B": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF", "NWPU-Crowd"],
    "ZIP-S": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-T": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-N": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
    "ZIP-P": ["ShanghaiTech A", "ShanghaiTech B", "UCF-QNRF"],
}

# -----------------------------
# Define the model architecture
# -----------------------------
def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
    """Load the model weights from the Hugging Face Hub."""
    global loaded_model
    # Build model
    model_info_path = hf_hub_download(
        repo_id=f"Yiming-M/{variant}",
        filename=f"checkpoints/{dataset}/best_{metric}.pth",
    )

    model = get_model(model_info_path=model_info_path)
    model.eval()
    loaded_model = model


def _calc_size(
    img_w: int,
    img_h: int,
    min_size: int,
    max_size: int,
    base: int = 32
):
    """
    Generate a new size for an image while keeping the aspect ratio. The new size should be within the given range (min_size, max_size).

    Args:
        img_w (int): The width of the image.
        img_h (int): The height of the image.
        min_size (int): The minimum size of the edges of the image.
        max_size (int): The maximum size of the edges of the image.
        base (int): The base number that the new size should be a multiple of.
    """
    assert min_size % base == 0, f"min_size ({min_size}) must be a multiple of {base}"
    if max_size != float("inf"):
        assert max_size % base == 0, f"max_size ({max_size}) must be a multiple of {base} if provided"

    assert min_size <= max_size, f"min_size ({min_size}) must be less than or equal to max_size ({max_size})"

    aspect_ratios = (img_w / img_h, img_h / img_w)
    if min_size / max_size <= min(aspect_ratios) <= max(aspect_ratios) <= max_size / min_size:  # possible to resize and preserve the aspect ratio
        if min_size <= min(img_w, img_h) <= max(img_w, img_h) <= max_size:  # already within the range, no need to resize
            ratio = 1.
        elif min(img_w, img_h) < min_size:  # smaller than the minimum size, resize to the minimum size
            ratio = min_size / min(img_w, img_h)
        else:  # larger than the maximum size, resize to the maximum size
            ratio = max_size / max(img_w, img_h)

        new_w, new_h = int(round(img_w * ratio / base) * base), int(round(img_h * ratio / base) * base)
        new_w = max(min_size, min(max_size, new_w))
        new_h = max(min_size, min(max_size, new_h))
        return new_w, new_h

    else:  # impossible to resize and preserve the aspect ratio
        msg = f"Impossible to resize {img_w}x{img_h} image while preserving the aspect ratio to a size within the range ({min_size}, {max_size}). Will not limit the maximum size."
        warn(msg)
        return _calc_size(img_w, img_h, min_size, float("inf"), base)


# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image, dataset_name: str) -> Tensor:
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)

    if dataset_name == "sha":
        min_size = 448
        max_size = float("inf")
    elif dataset_name == "shb":
        min_size = 448
        max_size = float("inf")
    elif dataset_name == "qnrf":
        min_size = 448
        max_size = 2048
    elif dataset_name == "nwpu":
        min_size = 448
        max_size = 3072

    image_height, image_width = image_tensor.shape[-2:]
    new_width, new_height = _calc_size(
        img_w=image_width,
        img_h=image_height,
        min_size=min_size,
        max_size=max_size,
        base=32
    )
    if new_height != image_height or new_width != image_width:
        image_tensor = TF.resize(image_tensor, size=(new_height, new_width), interpolation=TF.InterpolationMode.LANCZOS, antialias=True)

    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension


def _sliding_window_predict(
    model: nn.Module,
    image: Tensor,
    window_size: int,
    stride: int,
    max_num_windows: int = 256
):
    assert len(image.shape) == 4, f"Image must be a 4D tensor (1, c, h, w), got {image.shape}"
    window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
    stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
    window_size = tuple(window_size)
    stride = tuple(stride)
    assert isinstance(window_size, tuple) and len(window_size) == 2 and window_size[0] > 0 and window_size[1] > 0, f"Window size must be a positive integer tuple (h, w), got {window_size}"
    assert isinstance(stride, tuple) and len(stride) == 2 and stride[0] > 0 and stride[1] > 0, f"Stride must be a positive integer tuple (h, w), got {stride}"
    assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"Stride must be smaller than window size, got {stride} and {window_size}"

    image_height, image_width = image.shape[-2:]
    window_height, window_width = window_size
    assert image_height >= window_height and image_width >= window_width, f"Image size must be larger than window size, got image size {image.shape} and window size {window_size}"
    stride_height, stride_width = stride

    num_rows = int(np.ceil((image_height - window_height) / stride_height) + 1)
    num_cols = int(np.ceil((image_width - window_width) / stride_width) + 1)

    if hasattr(model, "block_size"):
        block_size = model.block_size
    elif hasattr(model, "module") and hasattr(model.module, "block_size"):
        block_size = model.module.block_size
    else:
        raise ValueError("Model must have block_size attribute")
    assert window_height % block_size == 0 and window_width % block_size == 0, f"Window size must be divisible by block size, got {window_size} and {block_size}"

    windows = []
    for i in range(num_rows):
        for j in range(num_cols):
            x_start, y_start = i * stride_height, j * stride_width
            x_end, y_end = x_start + window_height, y_start + window_width
            if x_end > image_height:
                x_start, x_end = image_height - window_height, image_height
            if y_end > image_width:
                y_start, y_end = image_width - window_width, image_width

            window = image[:, :, x_start:x_end, y_start:y_end]
            windows.append(window)

    windows = torch.cat(windows, dim=0).to(image.device)  # batched windows, shape: (num_windows, c, h, w)

    model.eval()
    pi_maps, lambda_maps = [], []
    for i in range(0, len(windows), max_num_windows):
        with torch.no_grad():
            image_feats = model.backbone(windows[i: min(i + max_num_windows, len(windows))])
            pi_image_feats, lambda_image_feats = model.pi_head(image_feats), model.lambda_head(image_feats)
            pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)

            pi_text_feats, lambda_text_feats = model.pi_text_feats, model.lambda_text_feats
            pi_logit_scale, lambda_logit_scale = model.pi_logit_scale.exp(), model.lambda_logit_scale.exp()

            pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t()  # (B, H, W, 2), logits per image
            lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t()  # (B, H, W, N - 1), logits per image

            pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
            lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)

            lambda_map = (lambda_logit_map.softmax(dim=1) * model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            pi_map = pi_logit_map.softmax(dim=1)[:, 0:1]  # (B, 1, H, W)

            pi_maps.append(pi_map.cpu().numpy())
            lambda_maps.append(lambda_map.cpu().numpy())

    # assemble the density map
    pi_maps = np.concatenate(pi_maps, axis=0)  # shape: (num_windows, 1, H, W)
    lambda_maps = np.concatenate(lambda_maps, axis=0)  # shape: (num_windows, 1, H, W)
    assert pi_maps.shape == lambda_maps.shape, f"pi_maps and lambda_maps must have the same shape, got {pi_maps.shape} and {lambda_maps.shape}"

    pi_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
    lambda_map = np.zeros((lambda_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
    count_map = np.zeros((pi_maps.shape[1], image_height // block_size, image_width // block_size), dtype=np.float32)
    idx = 0
    for i in range(num_rows):
        for j in range(num_cols):
            x_start, y_start = i * stride_height, j * stride_width
            x_end, y_end = x_start + window_height, y_start + window_width
            if x_end > image_height:
                x_start, x_end = image_height - window_height, image_height
            if y_end > image_width:
                y_start, y_end = image_width - window_width, image_width

            pi_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += pi_maps[idx, :, :, :]
            lambda_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += lambda_maps[idx, :, :, :]
            count_map[:, (x_start // block_size): (x_end // block_size), (y_start // block_size): (y_end // block_size)] += 1.
            idx += 1

    # average the density map
    pi_map /= count_map
    lambda_map /= count_map

    # convert to Tensor and reshape
    pi_map = torch.from_numpy(pi_map).unsqueeze(0)  # shape: (1, 1, H // block_size, W // block_size)
    lambda_map = torch.from_numpy(lambda_map).unsqueeze(0)  # shape: (1, 1, H // block_size, W // block_size)
    return pi_map, lambda_map


# -----------------------------
# Inference function
# -----------------------------
@spaces.GPU(duration=120)
def predict(image: Image.Image, variant: str, dataset: str, metric: str):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    global loaded_model

    if dataset == "ShanghaiTech A":
        dataset_name = "sha"
    elif dataset == "ShanghaiTech B":
        dataset_name = "shb"
    elif dataset == "UCF-QNRF":
        dataset_name = "qnrf"
    elif dataset == "NWPU-Crowd":
        dataset_name = "nwpu"

    if loaded_model is None:
        load_model(variant, dataset, metric)

    loaded_model.to(device)

    # Preprocess the image
    input_width, input_height = image.size
    image_tensor = transform(image, dataset_name).to(device)  # shape: (1, 3, H, W)

    input_size = loaded_model.input_size
    image_height, image_width = image_tensor.shape[-2:]
    aspect_ratio = image_width / image_height
    if image_height < input_size:
        new_height = input_size
        new_width = int(new_height * aspect_ratio)
        image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
        image_height, image_width = new_height, new_width
    if image_width < input_size:
        new_width = input_size
        new_height = int(new_width / aspect_ratio)
        image_tensor = F.interpolate(image_tensor, size=(new_height, new_width), mode="bicubic", align_corners=False, antialias=True)
        image_height, image_width = new_height, new_width

    with torch.no_grad():
        if hasattr(loaded_model, "num_vpt") and loaded_model.num_vpt > 0:  # For ViT models, use sliding window prediction
            # For ViT models with VPT
            pi_map, lambda_map = _sliding_window_predict(
                model=loaded_model,
                image=image_tensor,
                window_size=input_size,
                stride=input_size
            )

        elif hasattr(loaded_model, "pi_text_feats") and hasattr(loaded_model, "lambda_text_feats") and loaded_model.pi_text_feats is not None and loaded_model.lambda_text_feats is not None:  # For other CLIP-based models
            image_feats = loaded_model.backbone(image_tensor)
            # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            pi_image_feats, lambda_image_feats = loaded_model.pi_head(image_feats), loaded_model.lambda_head(image_feats)
            pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)

            pi_text_feats, lambda_text_feats = loaded_model.pi_text_feats, loaded_model.lambda_text_feats
            pi_logit_scale, lambda_logit_scale = loaded_model.pi_logit_scale.exp(), loaded_model.lambda_logit_scale.exp()

            pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t()  # (B, H, W, 2), logits per image
            lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t()  # (B, H, W, N - 1), logits per image

            pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
            lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)

            lambda_map = (lambda_logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            pi_map = pi_logit_map.softmax(dim=1)[:, 0:1]  # (B, 1, H, W)

        else:  # For non-CLIP models
            x = loaded_model.backbone(image_tensor)
            logit_pi_map = loaded_model.pi_head(x)  # shape: (B, 2, H, W)
            logit_map = loaded_model.bin_head(x)  # shape: (B, C, H, W)
            lambda_map = (logit_map.softmax(dim=1) * loaded_model.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # shape: (B, 1, H, W)
            pi_map = logit_pi_map.softmax(dim=1)[:, 0:1]  # shape: (B, 1, H, W)

    den_map = (1.0 - pi_map) * lambda_map  # shape: (B, 1, H, W)
    count = den_map.sum().item()

    structural_zero_map = F.interpolate(
        pi_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()

    lambda_map = F.interpolate(
        lambda_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()

    den_map = F.interpolate(
        den_map, size=(input_height, input_width), mode="bilinear", align_corners=False, antialias=True
    ).cpu().squeeze().numpy()

    sampling_zero_map = (1.0 - structural_zero_map) * np.exp(-lambda_map)
    complete_zero_map = structural_zero_map + sampling_zero_map

    # Normalize maps for display purposes
    def normalize_map(x: np.ndarray) -> np.ndarray:
        """Normalize the map to [0, 1] range for visualization."""
        x_min = np.min(x)
        x_max = np.max(x)
        if x_max - x_min < EPS:
            return np.zeros_like(x)
        return (x - x_min) / (x_max - x_min + EPS)

    structural_zero_map = normalize_map(structural_zero_map)
    sampling_zero_map = normalize_map(sampling_zero_map)
    lambda_map = normalize_map(lambda_map)
    den_map = normalize_map(den_map)
    complete_zero_map = normalize_map(complete_zero_map)

    # Apply a colormap (e.g., 'jet') to get an RGBA image
    colormap = cm.get_cmap("jet")

    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    den_map = (colormap(den_map) * 255).astype(np.uint8)
    structural_zero_map = (colormap(structural_zero_map) * 255).astype(np.uint8)
    sampling_zero_map = (colormap(sampling_zero_map) * 255).astype(np.uint8)
    lambda_map = (colormap(lambda_map) * 255).astype(np.uint8)
    complete_zero_map = (colormap(complete_zero_map) * 255).astype(np.uint8)

    # Convert to PIL images
    den_map = Image.fromarray(den_map).convert("RGBA")
    structural_zero_map = Image.fromarray(structural_zero_map).convert("RGBA")
    sampling_zero_map = Image.fromarray(sampling_zero_map).convert("RGBA")
    lambda_map = Image.fromarray(lambda_map).convert("RGBA")
    complete_zero_map = Image.fromarray(complete_zero_map).convert("RGBA")

    # Ensure the original image is in RGBA format.
    image_rgba = image.convert("RGBA")

    den_map = Image.blend(image_rgba, den_map, alpha=alpha)
    structural_zero_map = Image.blend(image_rgba, structural_zero_map, alpha=alpha)
    sampling_zero_map = Image.blend(image_rgba, sampling_zero_map, alpha=alpha)
    lambda_map = Image.blend(image_rgba, lambda_map, alpha=alpha)
    complete_zero_map = Image.blend(image_rgba, complete_zero_map, alpha=alpha)

    return image, structural_zero_map, sampling_zero_map, complete_zero_map, lambda_map, den_map, f"Predicted Count: {count:.2f}"


# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting by ZIP")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        with gr.Column():
            # Dropdown for model variant
            variant_dropdown = gr.Dropdown(
                choices=list(pretrained_datasets.keys()),
                value="ZIP-B",
                label="Select Model Variant"
            )

            # Dropdown for pretrained dataset, dynamically updated based on variant
            dataset_dropdown = gr.Dropdown(
                choices=pretrained_datasets["ZIP-B"],
                value=pretrained_datasets["ZIP-B"][0],
                label="Select Pretrained Dataset"
            )

            # Dropdown for metric, always the same choices
            metric_dropdown = gr.Dropdown(
                choices=["mae", "rmse", "nae"],
                value="mae",
                label="Select Best Metric"
            )

            # Update dataset choices when variant changes
            def update_dataset(variant):
                choices = pretrained_datasets[variant]
                return gr.update(
                    choices=choices,
                    value=choices[0]
                )

            variant_dropdown.change(
                fn=update_dataset,
                inputs=variant_dropdown,
                outputs=dataset_dropdown
            )
            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
            submit_btn = gr.Button("Predict")

        with gr.Column():
            output_den_map = gr.Image(label="Predicted Density Map", type="pil")
            output_structural_zero_map = gr.Image(label="Structural Zero Map", type="pil")
            output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
            output_lambda_map = gr.Image(label="Lambda Map", type="pil")
            output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")

            output_text = gr.Textbox(label="Total Count")

    submit_btn.click(
        fn=predict,
        inputs=[input_img, variant_dropdown, dataset_dropdown, metric_dropdown],
        outputs=[input_img, output_structural_zero_map, output_sampling_zero_map, output_complete_zero_map, output_lambda_map, output_den_map, output_text]
    )

    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
            ["example3.jpg"],
            ["example4.jpg"],
            ["example5.jpg"],
            ["example6.jpg"],
            ["example7.jpg"],
            ["example8.jpg"],
            ["example9.jpg"],
            ["example10.jpg"],
            ["example11.jpg"],
            ["example12.jpg"]
        ],
        inputs=input_img,
        label="Try an example"
    )

demo.launch()
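Note on the maps assembled in `predict` above: the model is zero-inflated, so each block carries a structural-zero probability pi and a Poisson rate lambda. The expected count per block is (1 - pi) * lambda, a block can be empty either structurally (probability pi) or by sampling (probability (1 - pi) * exp(-lambda)), and the two together give the complete-zero map. A minimal NumPy sketch of just these formulas on dummy per-block values (no model involved; the numbers are illustrative only):

import numpy as np

# Dummy 2x2 grid of blocks (purely illustrative values).
pi = np.array([[0.9, 0.1],
               [0.2, 0.8]])    # P(structural zero) per block
lam = np.array([[0.05, 3.0],
                [1.5, 0.2]])   # Poisson rate per block

density = (1.0 - pi) * lam                 # expected count per block (den_map)
sampling_zero = (1.0 - pi) * np.exp(-lam)  # empty despite a nonzero rate
complete_zero = pi + sampling_zero         # P(block is empty)

print(f"predicted count: {density.sum():.2f}")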
models/__init__.py
ADDED
@@ -0,0 +1,155 @@
import os, torch
from typing import List, Tuple, Optional, Union, Dict

from .ebc import _ebc, EBC
from .clip_ebc import _clip_ebc, CLIP_EBC


def get_model(
    model_info_path: str,
    model_name: Optional[str] = None,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    # parameters for CLIP_EBC
    clip_weight_name: Optional[str] = None,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    adapter: bool = False,
    adapter_reduction: Optional[int] = None,
    lora: bool = False,
    lora_rank: Optional[int] = None,
    lora_alpha: Optional[int] = None,
    lora_dropout: Optional[float] = None,
    norm: str = "none",
    act: str = "none",
    text_prompts: Optional[List[str]] = None
) -> Union[EBC, CLIP_EBC]:
    if os.path.exists(model_info_path):
        model_info = torch.load(model_info_path, map_location="cpu", weights_only=False)

        model_name = model_info["config"]["model_name"]
        block_size = model_info["config"]["block_size"]
        bins = model_info["config"]["bins"]
        bin_centers = model_info["config"]["bin_centers"]
        zero_inflated = model_info["config"]["zero_inflated"]

        clip_weight_name = model_info["config"].get("clip_weight_name", None)

        num_vpt = model_info["config"].get("num_vpt", None)
        vpt_drop = model_info["config"].get("vpt_drop", None)

        adapter = model_info["config"].get("adapter", False)
        adapter_reduction = model_info["config"].get("adapter_reduction", None)

        lora = model_info["config"].get("lora", False)
        lora_rank = model_info["config"].get("lora_rank", None)
        lora_alpha = model_info["config"].get("lora_alpha", None)
        lora_dropout = model_info["config"].get("lora_dropout", None)

        input_size = model_info["config"].get("input_size", None)
        text_prompts = model_info["config"].get("text_prompts", None)

        norm = model_info["config"].get("norm", "none")
        act = model_info["config"].get("act", "none")

        weights = model_info["weights"]

    else:
        assert model_name is not None, "model_name should be provided if model_info_path is not provided"
        assert block_size is not None, "block_size should be provided"
        assert bins is not None, "bins should be provided"
        assert bin_centers is not None, "bin_centers should be provided"
        weights = None

    if "ViT" in model_name:
        assert num_vpt is not None, f"num_vpt should be provided for ViT models, got {num_vpt}"
        assert vpt_drop is not None, f"vpt_drop should be provided for ViT models, got {vpt_drop}"

    if model_name.startswith("CLIP_") or model_name.startswith("CLIP-"):
        assert clip_weight_name is not None, f"clip_weight_name should be provided for CLIP models, got {clip_weight_name}"
        model = _clip_ebc(
            model_name=model_name[5:],
            weight_name=clip_weight_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            adapter=adapter,
            adapter_reduction=adapter_reduction,
            lora=lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            text_prompts=text_prompts,
            norm=norm,
            act=act
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "clip_weight_name": clip_weight_name,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "adapter": adapter,
            "adapter_reduction": adapter_reduction,
            "lora": lora,
            "lora_rank": lora_rank,
            "lora_alpha": lora_alpha,
            "lora_dropout": lora_dropout,
            "text_prompts": model.text_prompts,
            "norm": norm,
            "act": act
        }

    else:
        assert not adapter, "adapter for non-CLIP models is not implemented yet"
        assert not lora, "lora for non-CLIP models is not implemented yet"
        model = _ebc(
            model_name=model_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            norm=norm,
            act=act
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "norm": norm,
            "act": act
        }

    model.config = model_config
    model_info = {"config": model_config, "weights": weights}

    if weights is not None:
        model.load_state_dict(weights)

    if not os.path.exists(model_info_path):
        torch.save(model_info, model_info_path)

    return model


__all__ = ["get_model"]
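A hypothetical usage sketch of `get_model`: the local checkpoint path below is an assumption for illustration (in the Space the path comes from `hf_hub_download`); when the file exists, the stored config and weights are restored and the config is exposed on `model.config`.

# Hypothetical usage sketch; the checkpoint path is an assumption.
from models import get_model

model = get_model(model_info_path="checkpoints/best_mae.pth")  # assumed local file holding {"config", "weights"}
model.eval()
print(model.config["model_name"], model.config["block_size"])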
models/clip_ebc/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .model import CLIP_EBC, _clip_ebc


__all__ = [
    "CLIP_EBC",
    "_clip_ebc",
]
models/clip_ebc/convnext.py
ADDED
@@ -0,0 +1,199 @@
from torch import nn, Tensor
import open_clip
from peft import get_peft_model, LoraConfig

from ..utils import ConvRefine, ConvAdapter
from ..utils import ConvUpsample, _get_norm_layer, _get_activation


convnext_names_and_weights = {
    "convnext_base": ["laion400m_s13b_b51k"],  # 107.49M
    "convnext_base_w": ["laion2b_s13b_b82k", "laion2b_s13b_b82k_augreg", "laion_aesthetic_s13b_b82k"],  # 107.75M
    "convnext_base_w_320": ["laion_aesthetic_s13b_b82k", "laion_aesthetic_s13b_b82k_augreg"],  # 107.75M
    "convnext_large_d": ["laion2b_s26b_b102k_augreg"],  # 217.46M
    "convnext_large_d_320": ["laion2b_s29b_b131k_ft", "laion2b_s29b_b131k_ft_soup"],  # 217.46M
    "convnext_xxlarge": ["laion2b_s34b_b82k_augreg", "laion2b_s34b_b82k_augreg_rewind", "laion2b_s34b_b82k_augreg_soup"]  # 896.88M
}

refiner_channels = {
    "convnext_base": 1024,
    "convnext_base_w": 1024,
    "convnext_base_w_320": 1024,
    "convnext_large_d": 1536,
    "convnext_large_d_320": 1536,
    "convnext_xxlarge": 3072,
}

refiner_groups = {
    "convnext_base": 1,
    "convnext_base_w": 1,
    "convnext_base_w_320": 1,
    "convnext_large_d": refiner_channels["convnext_large_d"] // 512,  # 3
    "convnext_large_d_320": refiner_channels["convnext_large_d_320"] // 512,  # 3
    "convnext_xxlarge": refiner_channels["convnext_xxlarge"] // 512,  # 6
}


class ConvNeXt(nn.Module):
    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: int = 16,
        adapter: bool = False,
        adapter_reduction: int = 4,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super(ConvNeXt, self).__init__()
        assert model_name in convnext_names_and_weights, f"Model name should be one of {list(convnext_names_and_weights.keys())}, but got {model_name}."
        assert weight_name in convnext_names_and_weights[model_name], f"Pretrained should be one of {convnext_names_and_weights[model_name]}, but got {weight_name}."
        assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
        self.model_name, self.weight_name = model_name, weight_name
        self.block_size = block_size

        model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual

        self.adapter = adapter
        if adapter:
            self.adapter_reduction = adapter_reduction
            for param in model.parameters():
                param.requires_grad = False

        self.stem = model.trunk.stem
        self.depth = len(model.trunk.stages)
        for idx, stage in enumerate(model.trunk.stages):
            setattr(self, f"stage{idx}", stage)
            if adapter:
                setattr(self, f"adapter{idx}", ConvAdapter(
                    in_channels=stage.blocks[-1].mlp.fc2.out_features,
                    bottleneck_channels=stage.blocks[-1].mlp.fc2.out_features // adapter_reduction,
                ) if idx < self.depth - 1 else nn.Identity())  # No adapter for the last stage

        if self.model_name in ["convnext_base", "convnext_base_w", "convnext_base_w_320", "convnext_xxlarge"]:
            self.in_features, self.out_features = model.head.proj.in_features, model.head.proj.out_features
        else:  # "convnext_large_d", "convnext_large_d_320":
            self.in_features, self.out_features = model.head.mlp.fc1.in_features, model.head.mlp.fc2.out_features

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(model)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(model)

        if block_size == 32:
            self.refiner = ConvRefine(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[self.model_name],
            )
        elif block_size == 16:
            self.refiner = ConvUpsample(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[self.model_name],
            )
        else:  # block_size == 8
            self.refiner = nn.Sequential(
                ConvUpsample(
                    in_channels=self.in_features,
                    out_channels=self.in_features,
                    norm_layer=norm_layer,
                    activation=activation,
                    groups=refiner_groups[self.model_name],
                ),
                ConvUpsample(
                    in_channels=self.in_features,
                    out_channels=self.in_features,
                    norm_layer=norm_layer,
                    activation=activation,
                    groups=refiner_groups[self.model_name],
                ),
            )

    def train(self, mode: bool = True):
        if self.adapter and mode:
            # training:
            self.stem.eval()

            for idx in range(self.depth):
                getattr(self, f"stage{idx}").eval()
                getattr(self, f"adapter{idx}").train()

            self.refiner.train()

        else:
            # evaluation:
            for module in self.children():
                module.train(mode)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)

        for idx in range(self.depth):
            x = getattr(self, f"stage{idx}")(x)
            if self.adapter:
                x = getattr(self, f"adapter{idx}")(x)

        x = self.refiner(x)
        return x


def _convnext(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    adapter: bool = False,
    adapter_reduction: int = 4,
    lora: bool = False,
    lora_rank: int = 16,
    lora_alpha: float = 32.0,
    lora_dropout: float = 0.1,
    norm: str = "none",
    act: str = "none"
) -> ConvNeXt:
    assert not (lora and adapter), "Lora and adapter cannot be used together."
    model = ConvNeXt(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        norm=norm,
        act=act
    )

    if lora:
        target_modules = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)) and "refiner" not in name:
                target_modules.append(name)

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_config)

        # Unfreeze refiner
        for name, module in model.named_modules():
            if "refiner" in name:
                module.requires_grad_(True)

    return model
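The LoRA branch of `_convnext` above targets every `nn.Linear`/`nn.Conv2d` whose name does not contain "refiner", wraps the backbone with peft, and re-enables gradients on the refiner. A self-contained sketch of that selection pattern on a toy module (the toy backbone is hypothetical; only the selection logic mirrors the code above):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyBackbone(nn.Module):  # hypothetical stand-in for the CLIP visual trunk
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, padding=1)     # LoRA target
        self.head = nn.Linear(8, 4)                   # LoRA target
        self.refiner = nn.Conv2d(8, 8, 3, padding=1)  # excluded, kept fully trainable

    def forward(self, x):
        x = self.refiner(self.stem(x))
        return self.head(x.mean(dim=(2, 3)))

model = ToyBackbone()
target_modules = [
    name for name, module in model.named_modules()
    if isinstance(module, (nn.Linear, nn.Conv2d)) and "refiner" not in name
]
lora_config = LoraConfig(r=16, lora_alpha=32.0, lora_dropout=0.1, bias="none", target_modules=target_modules)
model = get_peft_model(model, lora_config)
for name, module in model.named_modules():
    if "refiner" in name:
        module.requires_grad_(True)  # refiner stays trainable alongside the LoRA adapters
model.print_trainable_parameters()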
models/clip_ebc/mobileclip.py
ADDED
@@ -0,0 +1,197 @@
from torch import nn, Tensor
import open_clip
from peft import get_peft_model, LoraConfig

from ..utils import ConvRefine, ConvUpsample, ConvAdapter
from ..utils import _get_norm_layer, _get_activation


mobileclip_names_and_weights = {
    "MobileCLIP-S1": ["datacompdr"],
    "MobileCLIP-S2": ["datacompdr"],
}


refiner_channels = {
    "MobileCLIP-S1": 1024,
    "MobileCLIP-S2": 1280,
}

refiner_groups = {
    "MobileCLIP-S1": 2,
    "MobileCLIP-S2": 2,
}


class MobileCLIP(nn.Module):
    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: int = 16,
        adapter: bool = False,
        adapter_reduction: int = 4,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super().__init__()
        assert model_name in mobileclip_names_and_weights, f"Model name should be one of {list(mobileclip_names_and_weights.keys())}, but got {model_name}."
        assert weight_name in mobileclip_names_and_weights[model_name], f"Pretrained should be one of {mobileclip_names_and_weights[model_name]}, but got {weight_name}."
        assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
        self.model_name, self.weight_name = model_name, weight_name
        self.block_size = block_size

        model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual

        self.adapter = adapter
        if adapter:
            for param in model.parameters():
                param.requires_grad = False

        self.stem = model.trunk.stem
        self.stages = model.trunk.stages

        self.depth = len(model.trunk.stages)
        for idx, stage in enumerate(model.trunk.stages):
            if adapter:
                setattr(self, f"adapter{idx}", ConvAdapter(
                    in_channels=stage.blocks[-1].mlp.fc2.out_channels,
                    bottleneck_channels=stage.blocks[-1].mlp.fc2.out_channels // adapter_reduction,
                ))

        self.final_conv = model.trunk.final_conv

        self.in_features, self.out_features = model.trunk.head.fc.in_features, model.trunk.head.fc.out_features

        # refine_block = LightConvRefine if model_name == "MobileCLIP-S1" else ConvRefine
        # upsample_block = LightConvUpsample if model_name == "MobileCLIP-S1" else ConvUpsample

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(model)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(model)

        if block_size == 32:
            self.refiner = ConvRefine(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[model_name],
            )
        elif block_size == 16:
            self.refiner = ConvUpsample(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[self.model_name],
            )
        else:  # block_size == 8
            self.refiner = nn.Sequential(
                ConvUpsample(
                    in_channels=self.in_features,
                    out_channels=self.in_features,
                    norm_layer=norm_layer,
                    activation=activation,
                    groups=refiner_groups[self.model_name],
                ),
                ConvUpsample(
                    in_channels=self.in_features,
                    out_channels=self.in_features,
                    norm_layer=norm_layer,
                    activation=activation,
                    groups=refiner_groups[self.model_name],
                ),
            )

    def train(self, mode: bool = True):
        if self.adapter and mode:
            # training:
            self.stem.eval()

            for idx in range(self.depth):
                self.stages[idx].eval()
                getattr(self, f"adapter{idx}").train()

            self.final_conv.eval()
            self.refiner.train()

        else:
            # evaluation:
            for module in self.children():
                module.train(mode)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)

        for idx in range(self.depth):
            x = self.stages[idx](x)
            if self.adapter:
                x = getattr(self, f"adapter{idx}")(x)

        x = self.final_conv(x)

        x = self.refiner(x)
        return x


def _mobileclip(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    adapter: bool = False,
    adapter_reduction: int = 4,
    lora: bool = False,
    lora_rank: int = 16,
    lora_alpha: float = 32.0,
    lora_dropout: float = 0.1,
    norm: str = "none",
    act: str = "none"
) -> MobileCLIP:
    assert not (lora and adapter), "Lora and adapter cannot be used together."
    model = MobileCLIP(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        norm=norm,
        act=act
    )

    if lora:
        target_modules = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                target_modules.append(name)

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_config)

        # Unfreeze the BN layers
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm2d) and "refiner" not in name:
                module.requires_grad_(True)

        # Unfreeze refiner
        for name, module in model.named_modules():
            if "refiner" in name:
                module.requires_grad_(True)

    return model
models/clip_ebc/model.py
ADDED
@@ -0,0 +1,272 @@
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import numpy as np
from typing import List, Optional, Dict, Tuple
from copy import deepcopy

from .vit import vit_names_and_weights, _vit
from .convnext import convnext_names_and_weights, _convnext
from .resnet import resnet_names_and_weights, _resnet
from .mobileclip import mobileclip_names_and_weights, _mobileclip

from .utils import encode_text, optimize_text_prompts
from ..utils import conv1x1

supported_models_and_weights = deepcopy(vit_names_and_weights)
supported_models_and_weights.update(convnext_names_and_weights)
supported_models_and_weights.update(resnet_names_and_weights)
supported_models_and_weights.update(mobileclip_names_and_weights)


class CLIP_EBC(nn.Module):
    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: Optional[int] = None,
        bins: Optional[List[Tuple[float, float]]] = None,
        bin_centers: Optional[List[float]] = None,
        zero_inflated: Optional[bool] = True,
        num_vpt: Optional[int] = None,
        vpt_drop: Optional[float] = None,
        input_size: Optional[int] = None,
        adapter: Optional[bool] = False,
        adapter_reduction: Optional[int] = None,
        lora: Optional[bool] = False,
        lora_rank: Optional[int] = None,
        lora_alpha: Optional[float] = None,
        lora_dropout: Optional[float] = None,
        text_prompts: Optional[Dict[str, List[str]]] = None,
        norm: Optional[str] = "none",
        act: Optional[str] = "none",
    ) -> None:
        super().__init__()
        if "mobileclip" in model_name.lower() or "vit" in model_name.lower():
            model_name = model_name.replace("_", "-")
        assert model_name in supported_models_and_weights, f"Model name should be one of {list(supported_models_and_weights.keys())}, but got {model_name}."
        assert weight_name in supported_models_and_weights[model_name], f"Pretrained should be one of {supported_models_and_weights[model_name]}, but got {weight_name}."
        assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}"
        assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}"
        assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
        bins = [(float(b[0]), float(b[1])) for b in bins]
        assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"

        self.model_name = model_name
        self.weight_name = weight_name
        self.block_size = block_size
        self.bins = bins
        self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1))
        self.zero_inflated = zero_inflated
        self.text_prompts = text_prompts

        # Image encoder
        if model_name in vit_names_and_weights:
            assert num_vpt is not None and num_vpt >= 0, f"Number of VPT tokens should be greater than 0, but got {num_vpt}."
            vpt_drop = 0. if vpt_drop is None else vpt_drop
            self.backbone = _vit(
                model_name=model_name,
                weight_name=weight_name,
                num_vpt=num_vpt,
                vpt_drop=vpt_drop,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                input_size=(input_size, input_size),
                norm=norm,
                act=act
            )
        elif model_name in convnext_names_and_weights:
            self.backbone = _convnext(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                norm=norm,
                act=act
            )
        elif model_name in resnet_names_and_weights:
            self.backbone = _resnet(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                norm=norm,
                act=act
            )
        elif model_name in mobileclip_names_and_weights:
            self.backbone = _mobileclip(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                norm=norm,
                act=act
            )

        self._build_text_feats()
        self._build_head()

    def _build_text_feats(self) -> None:
        model_name, weight_name = self.model_name, self.weight_name
        text_prompts = self.text_prompts

        if text_prompts is None:
            bins = [b[0] if b[0] == b[1] else b for b in self.bins]  # if the bin is a single value (e.g., [0, 0]), use that value
            if self.zero_inflated:  # separate 0 from the rest
                assert bins[0] == 0, f"Expected the first bin to be 0, got {bins[0]}."
                bins_pi = [0, (1, float("inf"))]
                bins_lambda = bins[1:]
                pi_text_prompts = optimize_text_prompts(model_name, weight_name, bins_pi)
                lambda_text_prompts = optimize_text_prompts(model_name, weight_name, bins_lambda)
                self.text_prompts = {"pi": pi_text_prompts, "lambda": lambda_text_prompts}
                pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
                lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
                pi_text_feats.requires_grad = False
                lambda_text_feats.requires_grad = False
                self.register_buffer("pi_text_feats", pi_text_feats)
                self.register_buffer("lambda_text_feats", lambda_text_feats)

            else:
                text_prompts = optimize_text_prompts(model_name, weight_name, bins)
                self.text_prompts = text_prompts
                text_feats = encode_text(model_name, weight_name, text_prompts)
                text_feats.requires_grad = False
                self.register_buffer("text_feats", text_feats)

        else:
            if self.zero_inflated:
                assert "pi" in text_prompts and "lambda" in text_prompts, f"Expected text_prompts to have keys 'pi' and 'lambda', got {text_prompts.keys()}."
                pi_text_prompts = text_prompts["pi"]
                lambda_text_prompts = text_prompts["lambda"]
                pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
                lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
                pi_text_feats.requires_grad = False
                lambda_text_feats.requires_grad = False
                self.register_buffer("pi_text_feats", pi_text_feats)
                self.register_buffer("lambda_text_feats", lambda_text_feats)

            else:
                text_feats = encode_text(model_name, weight_name, text_prompts)
                text_feats.requires_grad = False
                self.register_buffer("text_feats", text_feats)

    def _build_head(self) -> None:
        in_channels = self.backbone.in_features
        out_channels = self.backbone.out_features
        if self.zero_inflated:
            self.pi_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.lambda_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)

            self.pi_head = conv1x1(in_channels, out_channels, bias=False)
            self.lambda_head = conv1x1(in_channels, out_channels, bias=False)

        else:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.head = conv1x1(in_channels, out_channels, bias=False)

    def forward(self, image: Tensor):
        image_feats = self.backbone(image)
        # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)

        if self.zero_inflated:
            pi_image_feats, lambda_image_feats = self.pi_head(image_feats), self.lambda_head(image_feats)
            pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)

            pi_text_feats, lambda_text_feats = self.pi_text_feats, self.lambda_text_feats
            pi_logit_scale, lambda_logit_scale = self.pi_logit_scale.exp(), self.lambda_logit_scale.exp()

            pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t()  # (B, H, W, 2), logits per image
            lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t()  # (B, H, W, N - 1), logits per image

            pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
            lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)

            lambda_map = (lambda_logit_map.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)

            # pi_logit_map.softmax(dim=1)[:, 0] is the probability of zeros
            den_map = pi_logit_map.softmax(dim=1)[:, 1:] * lambda_map  # (B, 1, H, W)

            if self.training:
                return pi_logit_map, lambda_logit_map, lambda_map, den_map
            else:
                return den_map

        else:
            image_feats = self.head(image_feats)
            image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)

            text_feats = self.text_feats
            logit_scale = self.logit_scale.exp()

            logit_map = logit_scale * image_feats @ text_feats.t()  # (B, H, W, N), logits per image
            logit_map = logit_map.permute(0, 3, 1, 2)  # (B, N, H, W)

            den_map = (logit_map.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True)  # (B, 1, H, W)

            if self.training:
                return logit_map, den_map
            else:
                return den_map


def _clip_ebc(
    model_name: str,
    weight_name: str,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    adapter: Optional[bool] = False,
    adapter_reduction: Optional[int] = None,
    lora: Optional[bool] = False,
    lora_rank: Optional[int] = None,
    lora_alpha: Optional[float] = None,
    lora_dropout: Optional[float] = None,
    text_prompts: Optional[List[str]] = None,
    norm: Optional[str] = "none",
    act: Optional[str] = "none",
) -> CLIP_EBC:
    return CLIP_EBC(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        bins=bins,
        bin_centers=bin_centers,
        zero_inflated=zero_inflated,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        input_size=input_size,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        lora=lora,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        text_prompts=text_prompts,
        norm=norm,
        act=act,
    )
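For reference, a minimal usage sketch of the factory above. The bins, bin centers, and model/weight names are illustrative values chosen to satisfy the constructor's assertions, not a configuration from this repo; building the model downloads the CLIP weights and runs the prompt search in optimize_text_prompts, so construction is not instant.

import torch
from models.clip_ebc.model import _clip_ebc

# Illustrative bins: each bin is a (low, high) range with a representative center value.
bins = [(0, 0), (1, 1), (2, 2), (3, 3), (4, float("inf"))]
bin_centers = [0.0, 1.0, 2.0, 3.0, 5.0]

model = _clip_ebc(
    model_name="ViT-B-16",
    weight_name="openai",
    block_size=16,
    bins=bins,
    bin_centers=bin_centers,
    zero_inflated=True,
    num_vpt=32,      # required for ViT backbones
    vpt_drop=0.0,
    input_size=224,
)
model.eval()

with torch.no_grad():
    den_map = model(torch.randn(1, 3, 224, 224))  # eval mode returns only the density map
# den_map should have shape (1, 1, 14, 14) for block_size 16; summing it gives the predicted count.
print(den_map.shape, den_map.sum().item())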
models/clip_ebc/resnet.py
ADDED
@@ -0,0 +1,236 @@
from torch import nn, Tensor
import open_clip
from peft import get_peft_model, LoraConfig

from ..utils import ConvRefine, ConvUpsample, ConvAdapter
from ..utils import _get_norm_layer, _get_activation


resnet_names_and_weights = {
    "RN50": ["openai", "yfcc15m", "cc12m"],
    "RN101": ["openai", "yfcc15m", "cc12m"],
    "RN50x4": ["openai", "yfcc15m", "cc12m"],
    "RN50x16": ["openai", "yfcc15m", "cc12m"],
    "RN50x64": ["openai", "yfcc15m", "cc12m"],
}

refiner_channels = {
    "RN50": 2048,
    "RN101": 2048,
    "RN50x4": 2560,
    "RN50x16": 3072,
    "RN50x64": 4096,
}

refiner_groups = {
    "RN50": refiner_channels["RN50"] // 512,  # 4
    "RN101": refiner_channels["RN101"] // 512,  # 4
    "RN50x4": refiner_channels["RN50x4"] // 512,  # 5
    "RN50x16": refiner_channels["RN50x16"] // 512,  # 6
    "RN50x64": refiner_channels["RN50x64"] // 512,  # 8
}


class ResNet(nn.Module):
    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: int = 16,
        adapter: bool = False,
        adapter_reduction: int = 4,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super(ResNet, self).__init__()
        assert model_name in resnet_names_and_weights, f"Model name should be one of {list(resnet_names_and_weights.keys())}, but got {model_name}."
        assert weight_name in resnet_names_and_weights[model_name], f"Pretrained should be one of {resnet_names_and_weights[model_name]}, but got {weight_name}."
        assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
        self.model_name, self.weight_name = model_name, weight_name
        self.block_size = block_size

        model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual

        self.adapter = adapter
        if adapter:
            for param in model.parameters():
                param.requires_grad = False

        # Stem
        self.conv1 = model.conv1
        self.bn1 = model.bn1
        self.act1 = model.act1
        self.conv2 = model.conv2
        self.bn2 = model.bn2
        self.act2 = model.act2
        self.conv3 = model.conv3
        self.bn3 = model.bn3
        self.act3 = model.act3
        self.avgpool = model.avgpool
        # Stem: reduction = 4

        # Layers
        for idx in range(1, 5):
            setattr(self, f"layer{idx}", getattr(model, f"layer{idx}"))
            if adapter:
                setattr(self, f"adapter{idx}", ConvAdapter(
                    in_channels=getattr(model, f"layer{idx}")[-1].conv3.out_channels,
                    bottleneck_channels=getattr(model, f"layer{idx}")[-1].conv3.out_channels // adapter_reduction,
                ) if idx < 4 else nn.Identity())  # No adapter for the last layer

        self.in_features = model.attnpool.c_proj.weight.shape[1]
        self.out_features = model.attnpool.c_proj.weight.shape[0]

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(model)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(model)

        if block_size == 32:
            self.refiner = ConvRefine(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[self.model_name],
            )
        elif block_size == 16:
            self.refiner = ConvUpsample(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[self.model_name],
            )
        else:  # block_size == 8
            self.refiner = nn.Sequential(
                ConvUpsample(
                    in_channels=self.in_features,
                    out_channels=self.in_features,
                    norm_layer=norm_layer,
                    activation=activation,
                    groups=refiner_groups[self.model_name],
                ),
                ConvUpsample(
                    in_channels=self.in_features,
                    out_channels=self.in_features,
                    norm_layer=norm_layer,
                    activation=activation,
                    groups=refiner_groups[self.model_name],
                ),
            )

    def train(self, mode: bool = True):
        if self.adapter and mode:
            # training:
            self.conv1.eval()
            self.bn1.eval()
            self.act1.eval()
            self.conv2.eval()
            self.bn2.eval()
            self.act2.eval()
            self.conv3.eval()
            self.bn3.eval()
            self.act3.eval()
            self.avgpool.eval()

            for idx in range(1, 5):
                getattr(self, f"layer{idx}").eval()
                getattr(self, f"adapter{idx}").train()

            self.refiner.train()

        else:
            # evaluation:
            for module in self.children():
                module.train(mode)

    def stem(self, x: Tensor) -> Tensor:
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)

        x = self.layer1(x)
        if self.adapter:
            x = self.adapter1(x)

        x = self.layer2(x)
        if self.adapter:
            x = self.adapter2(x)

        x = self.layer3(x)
        if self.adapter:
            x = self.adapter3(x)

        x = self.layer4(x)
        if self.adapter:
            x = self.adapter4(x)

        x = self.refiner(x)
        return x


def _resnet(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    adapter: bool = False,
    adapter_reduction: int = 4,
    lora: bool = False,
    lora_rank: int = 16,
    lora_alpha: float = 32.0,
    lora_dropout: float = 0.1,
    norm: str = "none",
    act: str = "none"
) -> ResNet:
    assert not (lora and adapter), "Lora and adapter cannot be used together."
    model = ResNet(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        norm=norm,
        act=act
    )

    if lora:
        target_modules = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                target_modules.append(name)

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_config)

        # Unfreeze BN layers
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm2d) and "refiner" not in name:
                module.requires_grad_(True)

        # Unfreeze refiner
        for name, module in model.named_modules():
            if "refiner" in name:
                module.requires_grad_(True)

    return model
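As a rough illustration of what this backbone produces, the sketch below builds the plain (non-LoRA, non-adapter) RN50 variant and inspects the feature map. The expected shape assumes RN50's 2048-channel final stage and that ConvUpsample doubles the spatial resolution, as its name suggests; both are assumptions about modules not shown here.

import torch
from models.clip_ebc.resnet import _resnet

backbone = _resnet(model_name="RN50", weight_name="openai", block_size=16)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
# Expected roughly (1, 2048, 14, 14): the CLIP ResNet reduces the input by 32x,
# and the block_size=16 branch upsamples the final feature map once.
print(feats.shape)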
models/clip_ebc/utils.py
ADDED
@@ -0,0 +1,137 @@
import torch
from torch import Tensor, nn
import torch.nn.functional as F
import open_clip
from tqdm import tqdm
import numpy as np
from typing import Union, Tuple, List


num_to_word = {
    "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
    "10": "ten", "11": "eleven", "12": "twelve", "13": "thirteen", "14": "fourteen", "15": "fifteen", "16": "sixteen", "17": "seventeen", "18": "eighteen", "19": "nineteen",
    "20": "twenty", "21": "twenty-one", "22": "twenty-two", "23": "twenty-three", "24": "twenty-four", "25": "twenty-five", "26": "twenty-six", "27": "twenty-seven", "28": "twenty-eight", "29": "twenty-nine",
    "30": "thirty", "31": "thirty-one", "32": "thirty-two", "33": "thirty-three", "34": "thirty-four", "35": "thirty-five", "36": "thirty-six", "37": "thirty-seven", "38": "thirty-eight", "39": "thirty-nine",
    "40": "forty", "41": "forty-one", "42": "forty-two", "43": "forty-three", "44": "forty-four", "45": "forty-five", "46": "forty-six", "47": "forty-seven", "48": "forty-eight", "49": "forty-nine",
    "50": "fifty", "51": "fifty-one", "52": "fifty-two", "53": "fifty-three", "54": "fifty-four", "55": "fifty-five", "56": "fifty-six", "57": "fifty-seven", "58": "fifty-eight", "59": "fifty-nine",
    "60": "sixty", "61": "sixty-one", "62": "sixty-two", "63": "sixty-three", "64": "sixty-four", "65": "sixty-five", "66": "sixty-six", "67": "sixty-seven", "68": "sixty-eight", "69": "sixty-nine",
    "70": "seventy", "71": "seventy-one", "72": "seventy-two", "73": "seventy-three", "74": "seventy-four", "75": "seventy-five", "76": "seventy-six", "77": "seventy-seven", "78": "seventy-eight", "79": "seventy-nine",
    "80": "eighty", "81": "eighty-one", "82": "eighty-two", "83": "eighty-three", "84": "eighty-four", "85": "eighty-five", "86": "eighty-six", "87": "eighty-seven", "88": "eighty-eight", "89": "eighty-nine",
    "90": "ninety", "91": "ninety-one", "92": "ninety-two", "93": "ninety-three", "94": "ninety-four", "95": "ninety-five", "96": "ninety-six", "97": "ninety-seven", "98": "ninety-eight", "99": "ninety-nine",
    "100": "one hundred"
}

prefixes = [
    "",
    "A photo of", "A block of", "An image of", "A picture of",
    "There are",
    "The image contains", "The photo contains", "The picture contains",
    "The image shows", "The photo shows", "The picture shows",
]
arabic_numeral = [True, False]
compares = [
    "more than", "greater than", "higher than", "larger than", "bigger than", "greater than or equal to",
    "at least", "no less than", "not less than", "not fewer than", "not lower than", "not smaller than", "not less than or equal to",
    "over", "above", "beyond", "exceeding", "surpassing",
]
suffixes = [
    "people", "persons", "individuals", "humans", "faces", "heads", "figures", "",
]


def num2word(num: Union[int, str]) -> str:
    """
    Convert the input number to the corresponding English word. For example, 1 -> "one", 2 -> "two", etc.
    """
    num = str(int(num))
    return num_to_word.get(num, num)


def format_count(
    bins: List[Union[float, Tuple[float, float]]],
) -> List[List[str]]:
    text_prompts = []
    for prefix in prefixes:
        for numeral in arabic_numeral:
            for compare in compares:
                for suffix in suffixes:
                    prompts = []
                    for bin in bins:
                        if isinstance(bin, (int, float)):  # count is a single number
                            count = int(bin)
                            if count == 0 or count == 1:
                                count = num2word(count) if not numeral else count
                                prefix_ = "There is" if prefix == "There are" else prefix
                                suffix_ = "person" if suffix == "people" else suffix[:-1]
                                prompt = f"{prefix_} {count} {suffix_}"
                            else:  # count > 1
                                count = num2word(count) if not numeral else count
                                prompt = f"{prefix} {count} {suffix}"

                        elif bin[1] == float("inf"):  # count is (lower_bound, inf)
                            count = int(bin[0])
                            count = num2word(count) if not numeral else count
                            prompt = f"{prefix} {compare} {count} {suffix}"

                        else:  # bin is (lower_bound, upper_bound)
                            left, right = int(bin[0]), int(bin[1])
                            left, right = num2word(left) if not numeral else left, num2word(right) if not numeral else right
                            prompt = f"{prefix} between {left} and {right} {suffix}"

                        # Remove starting and trailing whitespaces
                        prompt = prompt.strip() + "."

                        prompts.append(prompt)

                    text_prompts.append(prompts)

    return text_prompts


def encode_text(
    model_name: str,
    weight_name: str,
    text: List[str]
) -> Tensor:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    text = open_clip.get_tokenizer(model_name)(text).to(device)
    model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).to(device)
    model.eval()
    with torch.no_grad():
        text_feats = model.encode_text(text)
        text_feats = F.normalize(text_feats, p=2, dim=-1).detach().cpu()
    return text_feats


def optimize_text_prompts(
    model_name: str,
    weight_name: str,
    flat_bins: List[Union[float, Tuple[float, float]]],
    batch_size: int = 1024,
) -> List[str]:
    text_prompts = format_count(flat_bins)

    # Find the template that has the smallest average similarity of bin prompts.
    print("Finding the best setup for text prompts...")
    text_prompts_ = [prompt for prompts in text_prompts for prompt in prompts]  # flatten the list
    text_feats = []
    for i in tqdm(range(0, len(text_prompts_), batch_size)):
        text_feats.append(encode_text(model_name, weight_name, text_prompts_[i: min(i + batch_size, len(text_prompts_))]))
    text_feats = torch.cat(text_feats, dim=0)

    sims = []
    for idx, prompts in enumerate(text_prompts):
        text_feats_ = text_feats[idx * len(prompts): (idx + 1) * len(prompts)]
        sim = torch.mm(text_feats_, text_feats_.T)
        sim = sim[~torch.eye(sim.shape[0], dtype=bool)].mean().item()
        sims.append(sim)

    optimal_prompts = text_prompts[np.argmin(sims)]
    sim = sims[np.argmin(sims)]
    print(f"Found the best text prompts: {optimal_prompts} (similarity: {sim:.2f})")
    return optimal_prompts
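To make the prompt search above concrete, here is what format_count generates for a small illustrative bin list (the bin values are made up). Each candidate list corresponds to one (prefix, numeral style, comparison, suffix) combination, and optimize_text_prompts keeps the list whose CLIP text embeddings have the lowest mean pairwise similarity.

from models.clip_ebc.utils import format_count

candidates = format_count([0, 1, (2, 4), (5, float("inf"))])
print(len(candidates))  # one candidate list per template combination
print(candidates[0])    # e.g. ["0 person.", "1 person.", "between 2 and 4 people.", "more than 5 people."]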
models/clip_ebc/vit.py
ADDED
@@ -0,0 +1,372 @@
import torch
from torch import nn, Tensor
import math
from einops import rearrange
import open_clip
from peft import get_peft_model, LoraConfig
from typing import Optional, Tuple

from ..utils import interpolate_pos_embed, ViTAdapter
# from ..utils import TransformerRefine, TransformerDownsample, TransformerUpsample
from ..utils import ConvRefine, ConvDownsample, ConvUpsample
from ..utils import _get_norm_layer, _get_activation


vit_names_and_weights = {
    "ViT-B-32": [
        "openai",
        "laion400m_e31", "laion400m_e32", "laion2b_e16", "laion2b_s34b_b79k",
        "datacomp_xl_s13b_b90k", "datacomp_m_s128m_b4k", "datacomp_s_s13m_b4k",
        "commonpool_m_clip_s128m_b4k", "commonpool_m_laion_s128m_b4k", "commonpool_m_image_s128m_b4k", "commonpool_m_text_s128m_b4k", "commonpool_m_basic_s128m_b4k", "commonpool_m_s128m_b4k",
        "commonpool_s_clip_s13m_b4k", "commonpool_s_laion_s13m_b4k", "commonpool_s_image_s13m_b4k", "commonpool_s_text_s13m_b4k", "commonpool_s_basic_s13m_b4k", "commonpool_s_s13m_b4k",
    ],
    "ViT-B-32-256": ["datacomp_s34b_b86k"],
    "ViT-B-16": [
        "openai",
        "laion400m_e31", "laion400m_e32", "laion2b_s34b_b88k",
        "datacomp_xl_s13b_b90k", "datacomp_l_s1b_b8k",
        "commonpool_l_clip_s1b_b8k", "commonpool_l_laion_s1b_b8k", "commonpool_l_image_s1b_b8k", "commonpool_l_text_s1b_b8k", "commonpool_l_basic_s1b_b8k", "commonpool_l_s1b_b8k",
        "dfn2b"
    ],
    "ViT-L-14": [
        "openai",
        "laion400m_e31", "laion400m_e32", "laion2b_s32b_b82k",
        "datacomp_xl_s13b_b90k",
        "commonpool_xl_clip_s13b_b90k", "commonpool_xl_laion_s13b_b90k", "commonpool_xl_s13b_b90k"
    ],
    "ViT-L-14-336": ["openai"],
    "ViT-H-14": ["laion2b_s32b_b79k"],
    "ViT-g-14": ["laion2b_s12b_b42k", "laion2b_s34b_b88k"],
    "ViT-bigG-14": ["laion2b_s39b_b160k"],
}


refiner_channels = {
    "ViT-B-32": 768,
    "ViT-B-32-256": 768,
    "ViT-B-16": 768,
    "ViT-L-14": 1024,
    "ViT-L-14-336": 1024,
    "ViT-H-14": 1280,
    "ViT-g-14": 1408,
    "ViT-bigG-14": 1664,
}

refiner_groups = {
    "ViT-B-32": 1,
    "ViT-B-32-256": 1,
    "ViT-B-16": 1,
    "ViT-L-14": 1,
    "ViT-L-14-336": 1,
    "ViT-H-14": 1,
    "ViT-g-14": refiner_channels["ViT-g-14"] // 704,  # 2
    "ViT-bigG-14": refiner_channels["ViT-bigG-14"] // 416,  # 4
}


class ViT(nn.Module):
    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: int = 16,
        num_vpt: int = 32,
        vpt_drop: float = 0.0,
        adapter: bool = False,
        adapter_reduction: int = 4,
        input_size: Optional[Tuple[int, int]] = None,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super(ViT, self).__init__()
        assert model_name in vit_names_and_weights, f"Model name should be one of {list(vit_names_and_weights.keys())}, but got {model_name}."
        assert weight_name in vit_names_and_weights[model_name], f"Pretrained should be one of {vit_names_and_weights[model_name]}, but got {weight_name}."
        if adapter:
            assert num_vpt is None or num_vpt == 0, "num_vpt should be None or 0 when using adapter."
            assert vpt_drop is None or vpt_drop == 0.0, "vpt_drop should be None or 0.0 when using adapter."
        else:
            assert num_vpt > 0, f"Number of VPT tokens should be greater than 0, but got {num_vpt}."
            assert 0.0 <= vpt_drop < 1.0, f"VPT dropout should be in [0.0, 1.0), but got {vpt_drop}."

        self.model_name, self.weight_name = model_name, weight_name
        self.block_size = block_size
        self.num_vpt = num_vpt
        self.vpt_drop = vpt_drop
        self.adapter = adapter

        model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual

        # Always freeze the parameters of the model
        for param in model.parameters():
            param.requires_grad = False

        # Setup the model
        self.input_size = input_size if input_size is not None else model.image_size
        self.pretrain_size = model.image_size
        self.patch_size = model.patch_size
        self.class_embedding = model.class_embedding
        self.positional_embedding = model.positional_embedding
        self.embed_dim = model.class_embedding.shape[-1]

        self.conv1 = model.conv1
        self.ln_pre = model.ln_pre
        self.resblocks = model.transformer.resblocks
        self.num_layers = len(self.resblocks)
        self.ln_post = model.ln_post

        # Setup VPT tokens
        val = math.sqrt(6. / float(3 * self.patch_size[0] + self.embed_dim))
        for idx in range(self.num_layers):
            if self.adapter:
                setattr(self, f"adapter{idx}", ViTAdapter(
                    in_channels=self.embed_dim,
                    bottleneck_channels=self.embed_dim // adapter_reduction,
                ))
            else:
                setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.embed_dim)))
                nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
                setattr(self, f"vpt_drop_{idx}", nn.Dropout(self.vpt_drop))

        # Adjust the positional embedding to match the new input size
        self._adjust_pos_embed()

        in_features, out_features = model.proj.shape
        self.in_features = in_features
        self.out_features = out_features

        patch_size = self.patch_size[0]
        if patch_size in [16, 32]:
            assert block_size in [8, 16, 32], f"Patch size is {patch_size}, but got block size {block_size}."
        else:  # patch_size == 14
            assert block_size in [7, 14, 28], f"Patch size is 14, but got block size {block_size}."

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(model)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(model)

        if block_size == patch_size:
            self.refiner = ConvRefine(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[self.model_name],
            )

        elif block_size < patch_size:  # upsample
            if block_size == 8 and patch_size == 32:
                self.refiner = nn.Sequential(
                    ConvUpsample(
                        in_channels=self.in_features,
                        out_channels=self.in_features,
                        norm_layer=norm_layer,
                        activation=activation,
                        groups=refiner_groups[self.model_name],
                    ),
                    ConvUpsample(
                        in_channels=self.in_features,
                        out_channels=self.in_features,
                        norm_layer=norm_layer,
                        activation=activation,
                        groups=refiner_groups[self.model_name],
                    ),
                )
            else:
                self.refiner = ConvUpsample(
                    in_channels=self.in_features,
                    out_channels=self.in_features,
                    norm_layer=norm_layer,
                    activation=activation,
                    groups=refiner_groups[self.model_name],
                )

        else:  # downsample
            assert block_size // patch_size == 2, f"Block size {block_size} should be 2 times the patch size {patch_size}."
            self.refiner = ConvDownsample(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=refiner_groups[self.model_name],
            )

    def _adjust_pos_embed(self) -> Tensor:
        """
        Adjust the positional embedding to match the spatial resolution of the feature map.

        Args:
            orig_h, orig_w: The original spatial resolution of the image.
            new_h, new_w: The new spatial resolution of the image.
        """
        self.positional_embedding = nn.Parameter(self._interpolate_pos_embed(self.pretrain_size[0], self.pretrain_size[1], self.input_size[0], self.input_size[1]), requires_grad=False)

    def _interpolate_pos_embed(self, orig_h: int, orig_w: int, new_h: int, new_w: int) -> Tensor:
        """
        Interpolate the positional embedding to match the spatial resolution of the feature map.

        Args:
            orig_h, orig_w: The original spatial resolution of the image.
            new_h, new_w: The new spatial resolution of the image.
        """
        if (orig_h, orig_w) == (new_h, new_w):
            return self.positional_embedding

        orig_h_patches, orig_w_patches = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
        new_h_patches, new_w_patches = new_h // self.patch_size[0], new_w // self.patch_size[1]
        class_pos_embed, patch_pos_embed = self.positional_embedding[:1, :], self.positional_embedding[1:, :]
        patch_pos_embed = rearrange(patch_pos_embed, "(h w) d -> d h w", h=orig_h_patches, w=orig_w_patches)
        patch_pos_embed = interpolate_pos_embed(patch_pos_embed, size=(new_h_patches, new_w_patches))
        patch_pos_embed = rearrange(patch_pos_embed, "d h w -> (h w) d")
        pos_embed = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
        return pos_embed

    def train(self, mode: bool = True):
        if mode:
            # training:
            self.conv1.eval()
            self.ln_pre.eval()
            self.resblocks.eval()
            self.ln_post.eval()

            for idx in range(self.num_layers):
                getattr(self, f"vpt_drop_{idx}").train()

            self.refiner.train()

        else:
            # evaluation:
            for module in self.children():
                module.train(mode)

    def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
        vpt = getattr(self, f"vpt_{layer}").unsqueeze(0).expand(batch_size, -1, -1).to(device)  # (batch_size, num_vpt, embed_dim)
        vpt = getattr(self, f"vpt_drop_{layer}")(vpt)

        return vpt

    def _forward_patch_embed(self, x: Tensor) -> Tensor:
        # This step performs 1) embed x into patches; 2) append the class token; 3) add positional embeddings.
        assert len(x.shape) == 4, f"Expected input to have shape (batch_size, 3, height, width), but got {x.shape}"
        batch_size, _, height, width = x.shape

        # Step 1: Embed x into patches
        x = self.conv1(x)

        # Step 2: Append the class token
        class_embedding = self.class_embedding.expand(batch_size, 1, -1)
        x = rearrange(x, "b d h w -> b (h w) d")
        x = torch.cat([class_embedding, x], dim=1)

        # Step 3: Add positional embeddings
        pos_embed = self._interpolate_pos_embed(orig_h=self.input_size[0], orig_w=self.input_size[1], new_h=height, new_w=width).expand(batch_size, -1, -1)
        x = x + pos_embed

        x = self.ln_pre(x)
        return x

    def _forward_vpt(self, x: Tensor, idx: int) -> Tensor:
        batch_size = x.shape[0]
        device = x.device

        # Assemble
        vpt = self._prepare_vpt(idx, batch_size, device)
        x = torch.cat([
            x[:, :1, :],  # class token
            vpt,
            x[:, 1:, :]  # patches
        ], dim=1)

        # Forward
        x = self.resblocks[idx](x)

        # Disassemble
        x = torch.cat([
            x[:, :1, :],  # class token
            x[:, 1 + self.num_vpt:, :]  # patches
        ], dim=1)

        return x

    def _forward_adapter(self, x: Tensor, idx: int) -> Tensor:
        return getattr(self, f"adapter{idx}")(x)

    def forward_encoder(self, x: Tensor) -> Tensor:
        x = self._forward_patch_embed(x)
        for idx in range(self.num_layers):
            x = self._forward_adapter(x, idx) if self.adapter else self._forward_vpt(x, idx)
        x = self.ln_post(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        orig_h, orig_w = x.shape[-2:]
        num_patches_h, num_patches_w = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
        x = self.forward_encoder(x)
        x = x[:, 1:, :]  # remove the class token
        x = rearrange(x, "b (h w) d -> b d h w", h=num_patches_h, w=num_patches_w)

        x = self.refiner(x)
        return x


def _vit(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    num_vpt: int = 32,
    vpt_drop: float = 0.1,
    adapter: bool = False,
    adapter_reduction: int = 4,
    lora: bool = False,
    lora_rank: int = 16,
    lora_alpha: float = 32.0,
    lora_dropout: float = 0.1,
    input_size: Optional[Tuple[int, int]] = None,
    norm: str = "none",
    act: str = "none"
) -> ViT:
    assert not (lora and adapter), "LoRA and adapter cannot be used together."
    model = ViT(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        input_size=input_size,
        norm=norm,
        act=act
    )

    if lora:
        target_modules = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.MultiheadAttention)) and "refiner" not in name:
                target_modules.append(name)

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_config)

        # Unfreeze refiner
        for name, module in model.named_modules():
            if "refiner" in name:
                module.requires_grad_(True)

    return model
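The refiner choice above follows the relation between block_size and the ViT patch size. A small sketch, assuming ConvUpsample doubles the spatial resolution (that module lives in models/utils and is not shown here): with ViT-B-16 (16x16 patches), block_size=8 selects a single upsampling refiner, so the output grid is twice as fine as the patch grid.

import torch
from models.clip_ebc.vit import _vit

backbone = _vit(
    model_name="ViT-B-16",
    weight_name="openai",
    block_size=8,          # finer than the 16x16 patches -> one ConvUpsample
    num_vpt=32,
    vpt_drop=0.0,
    input_size=(224, 224),
)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
# 224 / 16 = 14 patches per side, upsampled once -> expected roughly (1, 768, 28, 28).
print(feats.shape)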
models/ebc/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .model import EBC, _ebc

__all__ = ["EBC", "_ebc"]
models/ebc/cannet.py
ADDED
@@ -0,0 +1,105 @@
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import List, Optional

from .csrnet import _csrnet, _csrnet_bn
from ..utils import _init_weights

EPS = 1e-6


class ContextualModule(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int = 512,
        scales: List[int] = [1, 2, 3, 6],
    ) -> None:
        super().__init__()
        self.scales = scales
        self.multiscale_modules = nn.ModuleList([self.__make_scale__(in_channels, size) for size in scales])
        self.bottleneck = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.weight_net = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.apply(_init_weights)

    def __make_weight__(self, feature: Tensor, scale_feature: Tensor) -> Tensor:
        weight_feature = feature - scale_feature
        weight_feature = self.weight_net(weight_feature)
        return F.sigmoid(weight_feature)

    def __make_scale__(self, channels: int, size: int) -> nn.Module:
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(size, size)),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
        )

    def forward(self, feature: Tensor) -> Tensor:
        h, w = feature.shape[-2:]
        multiscale_feats = [F.interpolate(input=scale(feature), size=(h, w), mode="bilinear") for scale in self.multiscale_modules]
        weights = [self.__make_weight__(feature, scale_feature) for scale_feature in multiscale_feats]
        multiscale_feats = sum([multiscale_feats[i] * weights[i] for i in range(len(weights))]) / (sum(weights) + EPS)
        overall_features = torch.cat([multiscale_feats, feature], dim=1)
        overall_features = self.bottleneck(overall_features)
        overall_features = self.relu(overall_features)
        return overall_features


class CANNet(nn.Module):
    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
        scales: List[int] = [1, 2, 3, 6],
    ) -> None:
        super().__init__()
        assert model_name in ["csrnet", "csrnet_bn"], f"Model name should be one of ['csrnet', 'csrnet_bn'], but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
        assert isinstance(scales, (tuple, list)), f"scales should be a list or tuple, got {type(scales)}."
        assert len(scales) > 0, f"Expected at least one size, got {len(scales)}."
        assert all([isinstance(size, int) for size in scales]), f"Expected all size to be int, got {scales}."
        self.model_name = model_name
        self.scales = scales

        csrnet = _csrnet(block_size=block_size, norm=norm, act=act) if model_name == "csrnet" else _csrnet_bn(block_size=block_size, norm=norm, act=act)
        self.block_size = csrnet.block_size

        self.encoder = csrnet.encoder
        self.encoder_channels = csrnet.encoder_channels
        self.encoder_reduction = csrnet.encoder_reduction  # feature map size compared to input size

        self.refiner = nn.Sequential(
            csrnet.refiner,
            ContextualModule(csrnet.refiner_channels, 512, scales)
        )
        self.refiner_channels = 512
        self.refiner_reduction = csrnet.refiner_reduction  # feature map size compared to input size

        self.decoder = csrnet.decoder
        self.decoder_channels = csrnet.decoder_channels
        self.decoder_reduction = csrnet.decoder_reduction

    def encode(self, x: Tensor) -> Tensor:
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        x = self.encode(x)
        x = self.refine(x)
        x = self.decode(x)
        return x


def _cannet(block_size: Optional[int] = None, norm: str = "none", act: str = "none", scales: List[int] = [1, 2, 3, 6]) -> CANNet:
    return CANNet("csrnet", block_size=block_size, norm=norm, act=act, scales=scales)

def _cannet_bn(block_size: Optional[int] = None, norm: str = "none", act: str = "none", scales: List[int] = [1, 2, 3, 6]) -> CANNet:
    return CANNet("csrnet_bn", block_size=block_size, norm=norm, act=act, scales=scales)
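The ContextualModule can be exercised on its own, without downloading any backbone weights. This is a minimal sketch that just checks the multi-scale weighting preserves the spatial size and maps the channels to out_channels; the tensor sizes are illustrative.

import torch
from models.ebc.cannet import ContextualModule

ctx = ContextualModule(in_channels=512, out_channels=512, scales=[1, 2, 3, 6])
ctx.eval()

with torch.no_grad():
    out = ctx(torch.randn(1, 512, 28, 28))
print(out.shape)  # expected (1, 512, 28, 28)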
models/ebc/csrnet.py
ADDED
@@ -0,0 +1,104 @@
from torch import nn, Tensor
from torch.hub import load_state_dict_from_url
from typing import Optional

from .vgg import VGG
from .utils import make_vgg_layers, vgg_urls
from ..utils import _init_weights, ConvDownsample, _get_activation, _get_norm_layer

EPS = 1e-6


encoder_cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512]
decoder_cfg = [512, 512, 512, 256, 128]


class CSRNet(nn.Module):
    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super().__init__()
        assert model_name in ["vgg16", "vgg16_bn"], f"Model name should be one of ['vgg16', 'vgg16_bn'], but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name

        vgg = VGG(make_vgg_layers(encoder_cfg, in_channels=3, batch_norm="bn" in model_name, dilation=1))
        vgg.load_state_dict(load_state_dict_from_url(vgg_urls[model_name]), strict=False)
        self.encoder = vgg.features
        self.encoder_reduction = 8
        self.encoder_channels = 512
        self.block_size = block_size if block_size is not None else 8

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(vgg)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(vgg)

        if self.block_size == self.encoder_reduction:
            self.refiner = nn.Identity()
        elif self.block_size > self.encoder_reduction:
            if self.block_size == 32:
                self.refiner = nn.Sequential(
                    ConvDownsample(
                        in_channels=self.encoder_channels,
                        out_channels=self.encoder_channels,
                        norm_layer=norm_layer,
                        activation=activation,
                    ),
                    ConvDownsample(
                        in_channels=self.encoder_channels,
                        out_channels=self.encoder_channels,
                        norm_layer=norm_layer,
                        activation=activation,
                    )
                )
            elif self.block_size == 16:
                self.refiner = ConvDownsample(
                    in_channels=self.encoder_channels,
                    out_channels=self.encoder_channels,
                    norm_layer=norm_layer,
                    activation=activation,
                )
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        decoder = make_vgg_layers(decoder_cfg, in_channels=512, batch_norm="bn" in model_name, dilation=2)
        decoder.apply(_init_weights)
        self.decoder = decoder
        self.decoder_channels = decoder_cfg[-1]
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        x = self.encode(x)
        x = self.refine(x)
        x = self.decode(x)
        return x


def _csrnet(block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> CSRNet:
    return CSRNet("vgg16", block_size=block_size, norm=norm, act=act)

def _csrnet_bn(block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> CSRNet:
    return CSRNet("vgg16_bn", block_size=block_size, norm=norm, act=act)
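A brief sketch of the block-size handling above, under two assumptions about modules not shown here: ConvDownsample halves the spatial resolution, and the dilated VGG decoder keeps it unchanged. With block_size=32 the VGG features at 1/8 resolution are downsampled twice, so the output should sit at 1/32 of the input.

import torch
from models.ebc.csrnet import _csrnet_bn

model = _csrnet_bn(block_size=32)   # downloads the VGG16-BN weights from vgg_urls
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
# Expected roughly (1, 128, 7, 7): 224 / 32 = 7, with decoder_cfg[-1] = 128 channels.
print(out.shape)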
models/ebc/hrnet.py
ADDED
@@ -0,0 +1,195 @@
1 |
+
import timm
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn, Tensor
|
4 |
+
from functools import partial
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
from ..utils import ConvRefine, _get_norm_layer, _get_activation
|
8 |
+
|
9 |
+
|
10 |
+
available_hrnets = [
|
11 |
+
"hrnet_w18", "hrnet_w18_small", "hrnet_w18_small_v2",
|
12 |
+
"hrnet_w30", "hrnet_w32", "hrnet_w40", "hrnet_w44", "hrnet_w48", "hrnet_w64",
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
class HRNet(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
model_name: str,
|
20 |
+
block_size: Optional[int] = None,
|
21 |
+
norm: str = "none",
|
22 |
+
act: str = "none"
|
23 |
+
) -> None:
|
24 |
+
super().__init__()
|
25 |
+
assert model_name in available_hrnets, f"Model name should be one of {available_hrnets}"
|
26 |
+
assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
|
27 |
+
self.model_name = model_name
|
28 |
+
self.block_size = block_size if block_size is not None else 32
|
29 |
+
|
30 |
+
model = timm.create_model(model_name, pretrained=True)
|
31 |
+
|
32 |
+
self.conv1 = model.conv1
|
33 |
+
self.bn1 = model.bn1
|
34 |
+
self.act1 = model.act1
|
35 |
+
self.conv2 = model.conv2
|
36 |
+
self.bn2 = model.bn2
|
37 |
+
self.act2 = model.act2
|
38 |
+
|
39 |
+
self.layer1 = model.layer1
|
40 |
+
|
41 |
+
self.transition1 = model.transition1
|
42 |
+
self.stage2 = model.stage2
|
43 |
+
|
44 |
+
self.transition2 = model.transition2
|
45 |
+
self.stage3 = model.stage3
|
46 |
+
|
47 |
+
self.transition3 = model.transition3
|
48 |
+
self.stage4 = model.stage4
|
49 |
+
|
50 |
+
incre_modules = model.incre_modules
|
51 |
+
downsamp_modules = model.downsamp_modules
|
52 |
+
|
53 |
+
assert len(incre_modules) == 4, f"Expected 4 incre_modules, got {len(self.incre_modules)}"
|
54 |
+
assert len(downsamp_modules) == 3, f"Expected 3 downsamp_modules, got {len(self.downsamp_modules)}"
|
55 |
+
|
56 |
+
self.out_channels_4 = incre_modules[0][0].downsample[0].out_channels
|
57 |
+
self.out_channels_8 = incre_modules[1][0].downsample[0].out_channels
|
58 |
+
self.out_channels_16 = incre_modules[2][0].downsample[0].out_channels
|
59 |
+
self.out_channels_32 = incre_modules[3][0].downsample[0].out_channels
|
60 |
+
|
61 |
+
if self.block_size == 8:
|
62 |
+
self.encoder_reduction = 8
|
63 |
+
self.encoder_channels = self.out_channels_8
|
64 |
+
self.incre_modules = incre_modules[:2]
|
65 |
+
self.downsamp_modules = downsamp_modules[:1]
|
66 |
+
|
67 |
+
self.refiner = nn.Identity()
|
68 |
+
self.refiner_reduction = 8
|
69 |
+
self.refiner_channels = self.out_channels_8
|
70 |
+
|
71 |
+
elif self.block_size == 16:
|
72 |
+
self.encoder_reduction = 16
|
73 |
+
self.encoder_channels = self.out_channels_16
|
74 |
+
self.incre_modules = incre_modules[:3]
|
75 |
+
self.downsamp_modules = downsamp_modules[:2]
|
76 |
+
|
77 |
+
self.refiner = nn.Identity()
|
78 |
+
self.refiner_reduction = 16
|
79 |
+
self.refiner_channels = self.out_channels_16
|
80 |
+
|
81 |
+
else: # self.block_size == 32
|
82 |
+
self.encoder_reduction = 32
|
83 |
+
self.encoder_channels = self.out_channels_32
|
84 |
+
self.incre_modules = incre_modules
|
85 |
+
self.downsamp_modules = downsamp_modules
|
86 |
+
|
87 |
+
self.refiner = nn.Identity()
|
88 |
+
self.refiner_reduction = 32
|
89 |
+
self.refiner_channels = self.out_channels_32
|
90 |
+
|
91 |
+
# define the decoder
|
92 |
+
if self.refiner_channels <= 512:
|
93 |
+
groups = 1
|
94 |
+
elif self.refiner_channels <= 1024:
|
95 |
+
groups = 2
|
96 |
+
elif self.refiner_channels <= 2048:
|
97 |
+
groups = 4
|
98 |
+
else:
|
99 |
+
groups = 8
|
100 |
+
|
101 |
+
if norm == "bn":
|
102 |
+
norm_layer = nn.BatchNorm2d
|
103 |
+
elif norm == "ln":
|
104 |
+
norm_layer = nn.LayerNorm
|
105 |
+
else:
|
106 |
+
norm_layer = _get_norm_layer(model)
|
107 |
+
|
108 |
+
if act == "relu":
|
109 |
+
activation = nn.ReLU(inplace=True)
|
110 |
+
elif act == "gelu":
|
111 |
+
activation = nn.GELU()
|
112 |
+
else:
|
113 |
+
activation = _get_activation(model)
|
114 |
+
|
115 |
+
decoder_block = partial(ConvRefine, groups=groups, norm_layer=norm_layer, activation=activation)
|
116 |
+
if self.refiner_channels <= 256:
|
117 |
+
self.decoder = nn.Identity()
|
118 |
+
self.decoder_channels = self.refiner_channels
|
119 |
+
elif self.refiner_channels <= 512:
|
120 |
+
self.decoder = decoder_block(self.refiner_channels, self.refiner_channels // 2)
|
121 |
+
self.decoder_channels = self.refiner_channels // 2
|
122 |
+
elif self.refiner_channels <= 1024:
|
123 |
+
self.decoder = nn.Sequential(
|
124 |
+
decoder_block(self.refiner_channels, self.refiner_channels // 2),
|
125 |
+
decoder_block(self.refiner_channels // 2, self.refiner_channels // 4),
|
126 |
+
)
|
127 |
+
self.decoder_channels = self.refiner_channels // 4
|
128 |
+
else:
|
129 |
+
self.decoder = nn.Sequential(
|
130 |
+
decoder_block(self.refiner_channels, self.refiner_channels // 2),
|
131 |
+
decoder_block(self.refiner_channels // 2, self.refiner_channels // 4),
|
132 |
+
decoder_block(self.refiner_channels // 4, self.refiner_channels // 8),
|
133 |
+
)
|
134 |
+
self.decoder_channels = self.refiner_channels // 8
|
135 |
+
|
136 |
+
self.decoder_reduction = self.refiner_reduction
|
137 |
+
|
138 |
+
def _interpolate(self, x: Tensor) -> Tensor:
|
139 |
+
# This method adjusts the spatial dimensions of the input tensor so that they are divisible by 32.
|
140 |
+
if x.shape[-1] % 32 != 0 or x.shape[-2] % 32 != 0:
|
141 |
+
new_h = int(round(x.shape[-2] / 32) * 32)
|
142 |
+
new_w = int(round(x.shape[-1] / 32) * 32)
|
143 |
+
return F.interpolate(x, size=(new_h, new_w), mode="bicubic", align_corners=False)
|
144 |
+
|
145 |
+
return x
|
146 |
+
|
147 |
+
|
148 |
+
def encode(self, x: Tensor) -> Tensor:
|
149 |
+
x = self.conv1(x)
|
150 |
+
x = self.bn1(x)
|
151 |
+
x = self.act1(x)
|
152 |
+
|
153 |
+
x = self.conv2(x)
|
154 |
+
x = self.bn2(x)
|
155 |
+
x = self.act2(x)
|
156 |
+
|
157 |
+
x = self.layer1(x)
|
158 |
+
|
159 |
+
x = [t(x) for t in self.transition1]
|
160 |
+
x = self.stage2(x)
|
161 |
+
|
162 |
+
x = [t(x[-1]) if not isinstance(t, nn.Identity) else x[i] for i, t in enumerate(self.transition2)]
|
163 |
+
x = self.stage3(x)
|
164 |
+
|
165 |
+
x = [t(x[-1]) if not isinstance(t, nn.Identity) else x[i] for i, t in enumerate(self.transition3)]
|
166 |
+
x = self.stage4(x)
|
167 |
+
|
168 |
+
assert len(x) == 4, f"Expected 4 outputs, got {len(x)}"
|
169 |
+
|
170 |
+
feats = None
|
171 |
+
for i, incre in enumerate(self.incre_modules):
|
172 |
+
if feats is None:
|
173 |
+
feats = incre(x[i])
|
174 |
+
else:
|
175 |
+
down = self.downsamp_modules[i - 1] # needed for torchscript module indexing
|
176 |
+
feats = incre(x[i]) + down.forward(feats)
|
177 |
+
|
178 |
+
return feats
|
179 |
+
|
180 |
+
def refine(self, x: Tensor) -> Tensor:
|
181 |
+
return self.refiner(x)
|
182 |
+
|
183 |
+
def decode(self, x: Tensor) -> Tensor:
|
184 |
+
return self.decoder(x)
|
185 |
+
|
186 |
+
def forward(self, x: Tensor) -> Tensor:
|
187 |
+
x = self._interpolate(x)
|
188 |
+
x = self.encode(x)
|
189 |
+
x = self.refine(x)
|
190 |
+
x = self.decode(x)
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
def _hrnet(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> HRNet:
|
195 |
+
return HRNet(model_name, block_size, norm, act)
|
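A minimal usage sketch for the HRNet backbone above. This is illustrative only: it assumes the repository root is on PYTHONPATH, that timm can download the pretrained hrnet_w18 weights, and the input size is an arbitrary choice.
import torch
from models.ebc.hrnet import _hrnet

backbone = _hrnet("hrnet_w18", block_size=16)
x = torch.randn(1, 3, 224, 224)             # resized internally to a multiple of 32 if needed
with torch.no_grad():
    feats = backbone(x)                     # (1, backbone.decoder_channels, 224 // 16, 224 // 16)
print(feats.shape, backbone.decoder_reduction)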
models/ebc/model.py
ADDED
@@ -0,0 +1,199 @@
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from einops import rearrange
|
4 |
+
|
5 |
+
from typing import Tuple, Union, Dict, Optional, List
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
from .cannet import _cannet, _cannet_bn
|
9 |
+
from .csrnet import _csrnet, _csrnet_bn
|
10 |
+
from .vgg import _vgg_encoder_decoder, _vgg_encoder
|
11 |
+
from .vit import _vit, supported_vit_backbones
|
12 |
+
from .timm_models import _timm_model
|
13 |
+
from .timm_models import regular_models as timm_regular_models, heavy_models as timm_heavy_models, light_models as timm_light_models, lighter_models as timm_lighter_models
|
14 |
+
from .hrnet import _hrnet, available_hrnets
|
15 |
+
|
16 |
+
from ..utils import conv1x1
|
17 |
+
|
18 |
+
|
19 |
+
regular_models = [
|
20 |
+
"csrnet", "csrnet_bn",
|
21 |
+
"cannet", "cannet_bn",
|
22 |
+
"vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
|
23 |
+
"vgg11_ae", "vgg11_bn_ae", "vgg13_ae", "vgg13_bn_ae", "vgg16_ae", "vgg16_bn_ae", "vgg19_ae", "vgg19_bn_ae",
|
24 |
+
*timm_regular_models,
|
25 |
+
*available_hrnets,
|
26 |
+
]
|
27 |
+
|
28 |
+
heavy_models = timm_heavy_models
|
29 |
+
|
30 |
+
light_models = timm_light_models
|
31 |
+
|
32 |
+
lighter_models = timm_lighter_models
|
33 |
+
|
34 |
+
transformer_models = supported_vit_backbones
|
35 |
+
|
36 |
+
supported_models = regular_models + heavy_models + light_models + lighter_models + transformer_models
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
class EBC(nn.Module):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
model_name: str,
|
44 |
+
block_size: int,
|
45 |
+
bins: List[Tuple[float, float]],
|
46 |
+
bin_centers: List[float],
|
47 |
+
zero_inflated: bool = True,
|
48 |
+
num_vpt: Optional[int] = None,
|
49 |
+
vpt_drop: Optional[float] = None,
|
50 |
+
input_size: Optional[int] = None,
|
51 |
+
norm: str = "none",
|
52 |
+
act: str = "none"
|
53 |
+
) -> None:
|
54 |
+
super().__init__()
|
55 |
+
assert model_name in supported_models, f"Model name should be one of {supported_models}, but got {model_name}."
|
56 |
+
self.model_name = model_name
|
57 |
+
|
58 |
+
if input_size is not None:
|
59 |
+
input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
|
60 |
+
assert len(input_size) == 2 and input_size[0] > 0 and input_size[1] > 0, f"Expected input_size to be a tuple of two positive integers, got {input_size}"
|
61 |
+
self.input_size = input_size
|
62 |
+
|
63 |
+
assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}"
|
64 |
+
assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}"
|
65 |
+
assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
|
66 |
+
bins = [(float(b[0]), float(b[1])) for b in bins]
|
67 |
+
assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"
|
68 |
+
|
69 |
+
self.block_size = block_size
|
70 |
+
self.bins = bins
|
71 |
+
self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1))
|
72 |
+
|
73 |
+
self.zero_inflated = zero_inflated
|
74 |
+
self.num_vpt = num_vpt
|
75 |
+
self.vpt_drop = vpt_drop
|
76 |
+
self.input_size = input_size
|
77 |
+
|
78 |
+
self.norm = norm
|
79 |
+
self.act = act
|
80 |
+
|
81 |
+
self._build_backbone()
|
82 |
+
self._build_head()
|
83 |
+
|
84 |
+
def _build_backbone(self) -> None:
|
85 |
+
model_name = self.model_name
|
86 |
+
if model_name == "csrnet":
|
87 |
+
self.backbone = _csrnet(self.block_size, self.norm, self.act)
|
88 |
+
elif model_name == "csrnet_bn":
|
89 |
+
self.backbone = _csrnet_bn(self.block_size, self.norm, self.act)
|
90 |
+
elif model_name == "cannet":
|
91 |
+
self.backbone = _cannet(self.block_size, self.norm, self.act)
|
92 |
+
elif model_name == "cannet_bn":
|
93 |
+
self.backbone = _cannet_bn(self.block_size, self.norm, self.act)
|
94 |
+
elif model_name == "vgg11":
|
95 |
+
self.backbone = _vgg_encoder("vgg11", self.block_size, self.norm, self.act)
|
96 |
+
elif model_name == "vgg11_ae":
|
97 |
+
self.backbone = _vgg_encoder_decoder("vgg11", self.block_size, self.norm, self.act)
|
98 |
+
elif model_name == "vgg11_bn":
|
99 |
+
self.backbone = _vgg_encoder("vgg11_bn", self.block_size, self.norm, self.act)
|
100 |
+
elif model_name == "vgg11_bn_ae":
|
101 |
+
self.backbone = _vgg_encoder_decoder("vgg11_bn", self.block_size, self.norm, self.act)
|
102 |
+
elif model_name == "vgg13":
|
103 |
+
self.backbone = _vgg_encoder("vgg13", self.block_size, self.norm, self.act)
|
104 |
+
elif model_name == "vgg13_ae":
|
105 |
+
self.backbone = _vgg_encoder_decoder("vgg13", self.block_size, self.norm, self.act)
|
106 |
+
elif model_name == "vgg13_bn":
|
107 |
+
self.backbone = _vgg_encoder("vgg13_bn", self.block_size, self.norm, self.act)
|
108 |
+
elif model_name == "vgg13_bn_ae":
|
109 |
+
self.backbone = _vgg_encoder_decoder("vgg13_bn", self.block_size, self.norm, self.act)
|
110 |
+
elif model_name == "vgg16":
|
111 |
+
self.backbone = _vgg_encoder("vgg16", self.block_size, self.norm, self.act)
|
112 |
+
elif model_name == "vgg16_ae":
|
113 |
+
self.backbone = _vgg_encoder_decoder("vgg16", self.block_size, self.norm, self.act)
|
114 |
+
elif model_name == "vgg16_bn":
|
115 |
+
self.backbone = _vgg_encoder("vgg16_bn", self.block_size, self.norm, self.act)
|
116 |
+
elif model_name == "vgg16_bn_ae":
|
117 |
+
self.backbone = _vgg_encoder_decoder("vgg16_bn", self.block_size, self.norm, self.act)
|
118 |
+
elif model_name == "vgg19":
|
119 |
+
self.backbone = _vgg_encoder("vgg19", self.block_size, self.norm, self.act)
|
120 |
+
elif model_name == "vgg19_ae":
|
121 |
+
self.backbone = _vgg_encoder_decoder("vgg19", self.block_size, self.norm, self.act)
|
122 |
+
elif model_name == "vgg19_bn":
|
123 |
+
self.backbone = _vgg_encoder("vgg19_bn", self.block_size, self.norm, self.act)
|
124 |
+
elif model_name == "vgg19_bn_ae":
|
125 |
+
self.backbone = _vgg_encoder_decoder("vgg19_bn", self.block_size, self.norm, self.act)
|
126 |
+
elif model_name in supported_vit_backbones:
|
127 |
+
self.backbone = _vit(model_name, block_size=self.block_size, num_vpt=self.num_vpt, vpt_drop=self.vpt_drop, input_size=self.input_size, norm=self.norm, act=self.act)
|
128 |
+
elif model_name in available_hrnets:
|
129 |
+
self.backbone = _hrnet(model_name, block_size=self.block_size, norm=self.norm, act=self.act)
|
130 |
+
else:
|
131 |
+
self.backbone = _timm_model(model_name, self.block_size, self.norm, self.act)
|
132 |
+
|
133 |
+
def _build_head(self) -> None:
|
134 |
+
channels = self.backbone.decoder_channels
|
135 |
+
if self.zero_inflated:
|
136 |
+
self.bin_head = conv1x1(
|
137 |
+
in_channels=channels,
|
138 |
+
out_channels=len(self.bins) - 1,
|
139 |
+
)
|
140 |
+
self.pi_head = conv1x1(
|
141 |
+
in_channels=channels,
|
142 |
+
out_channels=2,
|
143 |
+
) # this models structural 0s.
|
144 |
+
else:
|
145 |
+
self.bin_head = conv1x1(
|
146 |
+
in_channels=channels,
|
147 |
+
out_channels=len(self.bins),
|
148 |
+
)
|
149 |
+
|
150 |
+
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
151 |
+
x = self.backbone(x)
|
152 |
+
|
153 |
+
if self.zero_inflated:
|
154 |
+
logit_pi_maps = self.pi_head(x) # shape: (B, 2, H, W)
|
155 |
+
logit_maps = self.bin_head(x) # shape: (B, C, H, W)
|
156 |
+
lambda_maps = (logit_maps.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # shape: (B, 1, H, W)
|
157 |
+
|
158 |
+
# logit_pi_maps.softmax(dim=1)[:, 0] is the probability of zeros
|
159 |
+
den_maps = logit_pi_maps.softmax(dim=1)[:, 1:] * lambda_maps # expectation of the Poisson distribution
|
160 |
+
|
161 |
+
if self.training:
|
162 |
+
return logit_pi_maps, logit_maps, lambda_maps, den_maps
|
163 |
+
else:
|
164 |
+
return den_maps
|
165 |
+
|
166 |
+
else:
|
167 |
+
logit_maps = self.bin_head(x)
|
168 |
+
den_maps = (logit_maps.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True)
|
169 |
+
|
170 |
+
if self.training:
|
171 |
+
return logit_maps, den_maps
|
172 |
+
else:
|
173 |
+
return den_maps
|
174 |
+
|
175 |
+
|
176 |
+
def _ebc(
|
177 |
+
model_name: str,
|
178 |
+
block_size: int,
|
179 |
+
bins: List[Tuple[float, float]],
|
180 |
+
bin_centers: List[float],
|
181 |
+
zero_inflated: bool = True,
|
182 |
+
num_vpt: Optional[int] = None,
|
183 |
+
vpt_drop: Optional[float] = None,
|
184 |
+
input_size: Optional[int] = None,
|
185 |
+
norm: str = "none",
|
186 |
+
act: str = "none"
|
187 |
+
) -> EBC:
|
188 |
+
return EBC(
|
189 |
+
model_name=model_name,
|
190 |
+
block_size=block_size,
|
191 |
+
bins=bins,
|
192 |
+
bin_centers=bin_centers,
|
193 |
+
zero_inflated=zero_inflated,
|
194 |
+
num_vpt=num_vpt,
|
195 |
+
vpt_drop=vpt_drop,
|
196 |
+
input_size=input_size,
|
197 |
+
norm=norm,
|
198 |
+
act=act
|
199 |
+
)
|
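A sketch of how the EBC wrapper above can be instantiated. The bins and bin_centers below are made-up placeholder values, not the configurations the repo actually trains with, and the pretrained VGG weights are downloaded on first use.
import torch
from models.ebc.model import _ebc

bins = [(0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (4.0, 8.0)]       # hypothetical count bins
bin_centers = [0.0, 1.0, 2.5, 6.0]                            # one representative value per bin
model = _ebc("vgg16_bn", block_size=16, bins=bins, bin_centers=bin_centers, zero_inflated=False)
model.eval()
with torch.no_grad():
    den = model(torch.randn(1, 3, 256, 256))                  # density map of shape (1, 1, 16, 16)
print(float(den.sum()))                                       # predicted count for the image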
models/ebc/timm_models.py
ADDED
@@ -0,0 +1,318 @@
1 |
+
from timm import create_model
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from typing import Optional
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
from ..utils import _get_activation, _get_norm_layer, ConvUpsample, ConvDownsample
|
7 |
+
from ..utils import LightConvUpsample, LightConvDownsample, LighterConvUpsample, LighterConvDownsample
|
8 |
+
from ..utils import ConvRefine, LightConvRefine, LighterConvRefine
|
9 |
+
|
10 |
+
regular_models = [
|
11 |
+
"resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
|
12 |
+
"convnext_nano", "convnext_tiny", "convnext_small", "convnext_base",
|
13 |
+
"mobilenetv4_conv_large",
|
14 |
+
]
|
15 |
+
|
16 |
+
heavy_models = [
|
17 |
+
"convnext_large", "convnext_xlarge", "convnext_xxlarge",
|
18 |
+
]
|
19 |
+
|
20 |
+
light_models = [
|
21 |
+
"mobilenetv1_100", "mobilenetv1_125",
|
22 |
+
"mobilenetv2_100", "mobilenetv2_140",
|
23 |
+
"mobilenetv3_large_100",
|
24 |
+
"mobilenetv4_conv_medium",
|
25 |
+
|
26 |
+
]
|
27 |
+
|
28 |
+
lighter_models = [
|
29 |
+
"mobilenetv2_050",
|
30 |
+
"mobilenetv3_small_050", "mobilenetv3_small_075", "mobilenetv3_small_100",
|
31 |
+
"mobilenetv4_conv_small_050", "mobilenetv4_conv_small"
|
32 |
+
]
|
33 |
+
|
34 |
+
supported_models = regular_models + heavy_models + light_models + lighter_models
|
35 |
+
|
36 |
+
|
37 |
+
refiner_in_channels = {
|
38 |
+
# ResNet
|
39 |
+
"resnet18": 512,
|
40 |
+
"resnet34": 512,
|
41 |
+
"resnet50": 2048,
|
42 |
+
"resnet101": 2048,
|
43 |
+
"resnet152": 2048,
|
44 |
+
# ConvNeXt
|
45 |
+
"convnext_nano": 640,
|
46 |
+
"convnext_tiny": 768,
|
47 |
+
"convnext_small": 768,
|
48 |
+
"convnext_base": 1024,
|
49 |
+
"convnext_large": 1536,
|
50 |
+
"convnext_xlarge": 2048,
|
51 |
+
"convnext_xxlarge": 3072,
|
52 |
+
# MobileNet V1
|
53 |
+
"mobilenetv1_100": 1024,
|
54 |
+
"mobilenetv1_125": 1280,
|
55 |
+
# MobileNet V2
|
56 |
+
"mobilenetv2_050": 160,
|
57 |
+
"mobilenetv2_100": 320,
|
58 |
+
"mobilenetv2_140": 448,
|
59 |
+
# MobileNet V3
|
60 |
+
"mobilenetv3_small_050": 288,
|
61 |
+
"mobilenetv3_small_075": 432,
|
62 |
+
"mobilenetv3_small_100": 576,
|
63 |
+
"mobilenetv3_large_100": 960,
|
64 |
+
# MobileNet V4
|
65 |
+
"mobilenetv4_conv_small_050": 480,
|
66 |
+
"mobilenetv4_conv_small": 960,
|
67 |
+
"mobilenetv4_conv_medium": 960,
|
68 |
+
"mobilenetv4_conv_large": 960,
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
refiner_out_channels = {
|
73 |
+
# ResNet
|
74 |
+
"resnet18": 512,
|
75 |
+
"resnet34": 512,
|
76 |
+
"resnet50": 2048,
|
77 |
+
"resnet101": 2048,
|
78 |
+
"resnet152": 2048,
|
79 |
+
# ConvNeXt
|
80 |
+
"convnext_nano": 640,
|
81 |
+
"convnext_tiny": 768,
|
82 |
+
"convnext_small": 768,
|
83 |
+
"convnext_base": 1024,
|
84 |
+
"convnext_large": 1536,
|
85 |
+
"convnext_xlarge": 2048,
|
86 |
+
"convnext_xxlarge": 3072,
|
87 |
+
# MobileNet V1
|
88 |
+
"mobilenetv1_100": 512,
|
89 |
+
"mobilenetv1_125": 640,
|
90 |
+
# MobileNet V2
|
91 |
+
"mobilenetv2_050": 160,
|
92 |
+
"mobilenetv2_100": 320,
|
93 |
+
"mobilenetv2_140": 448,
|
94 |
+
# MobileNet V3
|
95 |
+
"mobilenetv3_small_050": 288,
|
96 |
+
"mobilenetv3_small_075": 432,
|
97 |
+
"mobilenetv3_small_100": 576,
|
98 |
+
"mobilenetv3_large_100": 480,
|
99 |
+
# MobileNet V4
|
100 |
+
"mobilenetv4_conv_small_050": 480,
|
101 |
+
"mobilenetv4_conv_small": 960,
|
102 |
+
"mobilenetv4_conv_medium": 960,
|
103 |
+
"mobilenetv4_conv_large": 960,
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
groups = {
|
108 |
+
# ResNet
|
109 |
+
"resnet18": 1,
|
110 |
+
"resnet34": 1,
|
111 |
+
"resnet50": refiner_in_channels["resnet50"] // 512,
|
112 |
+
"resnet101": refiner_in_channels["resnet101"] // 512,
|
113 |
+
"resnet152": refiner_in_channels["resnet152"] // 512,
|
114 |
+
# ConvNeXt
|
115 |
+
"convnext_nano": 8,
|
116 |
+
"convnext_tiny": 8,
|
117 |
+
"convnext_small": 8,
|
118 |
+
"convnext_base": 8,
|
119 |
+
"convnext_large": refiner_in_channels["convnext_large"] // 512,
|
120 |
+
"convnext_xlarge": refiner_in_channels["convnext_xlarge"] // 512,
|
121 |
+
"convnext_xxlarge": refiner_in_channels["convnext_xxlarge"] // 512,
|
122 |
+
# MobileNet V1
|
123 |
+
"mobilenetv1_100": None,
|
124 |
+
"mobilenetv1_125": None,
|
125 |
+
# MobileNet V2
|
126 |
+
"mobilenetv2_050": None,
|
127 |
+
"mobilenetv2_100": None,
|
128 |
+
"mobilenetv2_140": None,
|
129 |
+
# MobileNet V3
|
130 |
+
"mobilenetv3_small_050": None,
|
131 |
+
"mobilenetv3_small_075": None,
|
132 |
+
"mobilenetv3_small_100": None,
|
133 |
+
"mobilenetv3_large_100": None,
|
134 |
+
# MobileNet V4
|
135 |
+
"mobilenetv4_conv_small_050": None,
|
136 |
+
"mobilenetv4_conv_small": None,
|
137 |
+
"mobilenetv4_conv_medium": None,
|
138 |
+
"mobilenetv4_conv_large": 1,
|
139 |
+
}
|
140 |
+
|
141 |
+
|
142 |
+
class TIMMModel(nn.Module):
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
model_name: str,
|
146 |
+
block_size: Optional[int] = None,
|
147 |
+
norm: str = "none",
|
148 |
+
act: str = "none"
|
149 |
+
) -> None:
|
150 |
+
super().__init__()
|
151 |
+
assert model_name in supported_models, f"Backbone {model_name} not supported. Supported models are {supported_models}"
|
152 |
+
assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
|
153 |
+
self.model_name = model_name
|
154 |
+
self.encoder = create_model(model_name, pretrained=True, features_only=True, out_indices=[-1])
|
155 |
+
self.encoder_channels = self.encoder.feature_info.channels()[-1]
|
156 |
+
self.encoder_reduction = self.encoder.feature_info.reduction()[-1]
|
157 |
+
self.block_size = block_size if block_size is not None else self.encoder_reduction
|
158 |
+
|
159 |
+
if model_name in lighter_models:
|
160 |
+
upsample_block = LighterConvUpsample
|
161 |
+
downsample_block = LighterConvDownsample
|
162 |
+
decoder_block = LighterConvRefine
|
163 |
+
elif model_name in light_models:
|
164 |
+
upsample_block = LightConvUpsample
|
165 |
+
downsample_block = LightConvDownsample
|
166 |
+
decoder_block = LightConvRefine
|
167 |
+
else:
|
168 |
+
upsample_block = partial(ConvUpsample, groups=groups[model_name])
|
169 |
+
downsample_block = partial(ConvDownsample, groups=groups[model_name])
|
170 |
+
decoder_block = partial(ConvRefine, groups=groups[model_name])
|
171 |
+
|
172 |
+
|
173 |
+
if norm == "bn":
|
174 |
+
norm_layer = nn.BatchNorm2d
|
175 |
+
elif norm == "ln":
|
176 |
+
norm_layer = nn.LayerNorm
|
177 |
+
else:
|
178 |
+
norm_layer = _get_norm_layer(self.encoder)
|
179 |
+
|
180 |
+
if act == "relu":
|
181 |
+
activation = nn.ReLU(inplace=True)
|
182 |
+
elif act == "gelu":
|
183 |
+
activation = nn.GELU()
|
184 |
+
else:
|
185 |
+
activation = _get_activation(self.encoder)
|
186 |
+
|
187 |
+
if self.block_size > self.encoder_reduction:
|
188 |
+
if self.block_size > self.encoder_reduction * 2:
|
189 |
+
assert self.block_size == self.encoder_reduction * 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
|
190 |
+
self.refiner = nn.Sequential(
|
191 |
+
downsample_block(
|
192 |
+
in_channels=self.encoder_channels,
|
193 |
+
out_channels=refiner_in_channels[self.model_name],
|
194 |
+
norm_layer=norm_layer,
|
195 |
+
activation=activation,
|
196 |
+
),
|
197 |
+
downsample_block(
|
198 |
+
in_channels=refiner_in_channels[self.model_name],
|
199 |
+
out_channels=refiner_out_channels[self.model_name],
|
200 |
+
norm_layer=norm_layer,
|
201 |
+
activation=activation,
|
202 |
+
)
|
203 |
+
)
|
204 |
+
else:
|
205 |
+
assert self.block_size == self.encoder_reduction * 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
|
206 |
+
self.refiner = downsample_block(
|
207 |
+
in_channels=self.encoder_channels,
|
208 |
+
out_channels=refiner_out_channels[self.model_name],
|
209 |
+
norm_layer=norm_layer,
|
210 |
+
activation=activation,
|
211 |
+
)
|
212 |
+
|
213 |
+
self.refiner_channels = refiner_out_channels[self.model_name]
|
214 |
+
|
215 |
+
elif self.block_size < self.encoder_reduction:
|
216 |
+
if self.block_size < self.encoder_reduction // 2:
|
217 |
+
assert self.block_size == self.encoder_reduction // 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
|
218 |
+
self.refiner = nn.Sequential(
|
219 |
+
upsample_block(
|
220 |
+
in_channels=self.encoder_channels,
|
221 |
+
out_channels=refiner_in_channels[self.model_name],
|
222 |
+
norm_layer=norm_layer,
|
223 |
+
activation=activation,
|
224 |
+
),
|
225 |
+
upsample_block(
|
226 |
+
in_channels=refiner_in_channels[self.model_name],
|
227 |
+
out_channels=refiner_out_channels[self.model_name],
|
228 |
+
norm_layer=norm_layer,
|
229 |
+
activation=activation,
|
230 |
+
)
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
assert self.block_size == self.encoder_reduction // 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
|
234 |
+
self.refiner = upsample_block(
|
235 |
+
in_channels=self.encoder_channels,
|
236 |
+
out_channels=refiner_out_channels[self.model_name],
|
237 |
+
norm_layer=norm_layer,
|
238 |
+
activation=activation,
|
239 |
+
)
|
240 |
+
|
241 |
+
self.refiner_channels = refiner_out_channels[self.model_name]
|
242 |
+
|
243 |
+
else:
|
244 |
+
self.refiner = nn.Identity()
|
245 |
+
self.refiner_channels = self.encoder_channels
|
246 |
+
|
247 |
+
self.refiner_reduction = self.block_size
|
248 |
+
|
249 |
+
if self.refiner_channels <= 256:
|
250 |
+
self.decoder = nn.Identity()
|
251 |
+
self.decoder_channels = self.refiner_channels
|
252 |
+
elif self.refiner_channels <= 512:
|
253 |
+
self.decoder = decoder_block(
|
254 |
+
in_channels=self.refiner_channels,
|
255 |
+
out_channels=self.refiner_channels // 2,
|
256 |
+
norm_layer=norm_layer,
|
257 |
+
activation=activation,
|
258 |
+
)
|
259 |
+
self.decoder_channels = self.refiner_channels // 2
|
260 |
+
elif self.refiner_channels <= 1024:
|
261 |
+
self.decoder = nn.Sequential(
|
262 |
+
decoder_block(
|
263 |
+
in_channels=self.refiner_channels,
|
264 |
+
out_channels=self.refiner_channels // 2,
|
265 |
+
norm_layer=norm_layer,
|
266 |
+
activation=activation,
|
267 |
+
),
|
268 |
+
decoder_block(
|
269 |
+
in_channels=self.refiner_channels // 2,
|
270 |
+
out_channels=self.refiner_channels // 4,
|
271 |
+
norm_layer=norm_layer,
|
272 |
+
activation=activation,
|
273 |
+
),
|
274 |
+
)
|
275 |
+
self.decoder_channels = self.refiner_channels // 4
|
276 |
+
else:
|
277 |
+
self.decoder = nn.Sequential(
|
278 |
+
decoder_block(
|
279 |
+
in_channels=self.refiner_channels,
|
280 |
+
out_channels=self.refiner_channels // 2,
|
281 |
+
norm_layer=norm_layer,
|
282 |
+
activation=activation,
|
283 |
+
),
|
284 |
+
decoder_block(
|
285 |
+
in_channels=self.refiner_channels // 2,
|
286 |
+
out_channels=self.refiner_channels // 4,
|
287 |
+
norm_layer=norm_layer,
|
288 |
+
activation=activation,
|
289 |
+
),
|
290 |
+
decoder_block(
|
291 |
+
in_channels=self.refiner_channels // 4,
|
292 |
+
out_channels=self.refiner_channels // 8,
|
293 |
+
norm_layer=norm_layer,
|
294 |
+
activation=activation,
|
295 |
+
),
|
296 |
+
)
|
297 |
+
self.decoder_channels = self.refiner_channels // 8
|
298 |
+
|
299 |
+
self.decoder_reduction = self.refiner_reduction
|
300 |
+
|
301 |
+
def encode(self, x: Tensor) -> Tensor:
|
302 |
+
return self.encoder(x)[0]
|
303 |
+
|
304 |
+
def refine(self, x: Tensor) -> Tensor:
|
305 |
+
return self.refiner(x)
|
306 |
+
|
307 |
+
def decode(self, x: Tensor) -> Tensor:
|
308 |
+
return self.decoder(x)
|
309 |
+
|
310 |
+
def forward(self, x: Tensor) -> Tensor:
|
311 |
+
x = self.encode(x)
|
312 |
+
x = self.refine(x)
|
313 |
+
x = self.decode(x)
|
314 |
+
return x
|
315 |
+
|
316 |
+
|
317 |
+
def _timm_model(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> TIMMModel:
|
318 |
+
return TIMMModel(model_name, block_size=block_size, norm=norm, act=act)
|
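A short sketch of the block_size handling in TIMMModel above: with resnet18 the encoder reduction is 32, so requesting block_size=16 makes the refiner upsample once. This assumes the timm resnet18 weights can be downloaded.
import torch
from models.ebc.timm_models import _timm_model

backbone = _timm_model("resnet18", block_size=16)   # encoder reduction 32 -> one upsample block in the refiner
with torch.no_grad():
    y = backbone(torch.randn(1, 3, 224, 224))
print(y.shape, backbone.decoder_channels)           # spatial size is 224 // 16 = 14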
models/ebc/utils.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
from torch import nn
|
2 |
+
from typing import Union, List
|
3 |
+
|
4 |
+
|
5 |
+
vgg_urls = {
|
6 |
+
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
|
7 |
+
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
|
8 |
+
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
|
9 |
+
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
|
10 |
+
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
|
11 |
+
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
|
12 |
+
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
|
13 |
+
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
vgg_cfgs = {
|
18 |
+
"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
|
19 |
+
"B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
|
20 |
+
"D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512],
|
21 |
+
"E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512]
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
def make_vgg_layers(cfg: List[Union[str, int]], in_channels: int = 3, batch_norm: bool = False, dilation: int = 1) -> nn.Sequential:
|
26 |
+
layers = []
|
27 |
+
for v in cfg:
|
28 |
+
if v == "M":
|
29 |
+
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
30 |
+
else:
|
31 |
+
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=dilation, dilation=dilation)
|
32 |
+
if batch_norm:
|
33 |
+
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
|
34 |
+
else:
|
35 |
+
layers += [conv2d, nn.ReLU(inplace=True)]
|
36 |
+
in_channels = v
|
37 |
+
return nn.Sequential(*layers)
|
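For illustration, make_vgg_layers above can be used on its own to build the untrained conv stack of a configuration; no pretrained weights are involved here.
from models.ebc.utils import make_vgg_layers, vgg_cfgs

features = make_vgg_layers(vgg_cfgs["D"], in_channels=3, batch_norm=True)  # VGG-16-style conv stack
print(len(features), sum(p.numel() for p in features.parameters()))        # layer count and parameter count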
models/ebc/vgg.py
ADDED
@@ -0,0 +1,255 @@
1 |
+
from torch import nn, Tensor
|
2 |
+
from torch.hub import load_state_dict_from_url
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from .utils import make_vgg_layers, vgg_cfgs, vgg_urls
|
6 |
+
from ..utils import _init_weights, _get_norm_layer, _get_activation
|
7 |
+
from ..utils import ConvDownsample, ConvUpsample
|
8 |
+
|
9 |
+
|
10 |
+
vgg_models = [
|
11 |
+
"vgg11", "vgg11_bn",
|
12 |
+
"vgg13", "vgg13_bn",
|
13 |
+
"vgg16", "vgg16_bn",
|
14 |
+
"vgg19", "vgg19_bn",
|
15 |
+
]
|
16 |
+
|
17 |
+
decoder_cfg = [512, 256, 128]
|
18 |
+
|
19 |
+
|
20 |
+
class VGGEncoder(nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
model_name: str,
|
24 |
+
block_size: Optional[int] = None,
|
25 |
+
norm: str = "none",
|
26 |
+
act: str = "none",
|
27 |
+
) -> None:
|
28 |
+
super().__init__()
|
29 |
+
assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
|
30 |
+
assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
|
31 |
+
self.model_name = model_name
|
32 |
+
|
33 |
+
if model_name == "vgg11":
|
34 |
+
self.encoder = vgg11()
|
35 |
+
elif model_name == "vgg11_bn":
|
36 |
+
self.encoder = vgg11_bn()
|
37 |
+
elif model_name == "vgg13":
|
38 |
+
self.encoder = vgg13()
|
39 |
+
elif model_name == "vgg13_bn":
|
40 |
+
self.encoder = vgg13_bn()
|
41 |
+
elif model_name == "vgg16":
|
42 |
+
self.encoder = vgg16()
|
43 |
+
elif model_name == "vgg16_bn":
|
44 |
+
self.encoder = vgg16_bn()
|
45 |
+
elif model_name == "vgg19":
|
46 |
+
self.encoder = vgg19()
|
47 |
+
else: # model_name == "vgg19_bn"
|
48 |
+
self.encoder = vgg19_bn()
|
49 |
+
|
50 |
+
self.encoder_channels = 512
|
51 |
+
self.encoder_reduction = 16
|
52 |
+
self.block_size = block_size if block_size is not None else self.encoder_reduction
|
53 |
+
|
54 |
+
if norm == "bn":
|
55 |
+
norm_layer = nn.BatchNorm2d
|
56 |
+
elif norm == "ln":
|
57 |
+
norm_layer = nn.LayerNorm
|
58 |
+
else:
|
59 |
+
norm_layer = _get_norm_layer(self.encoder)
|
60 |
+
|
61 |
+
if act == "relu":
|
62 |
+
activation = nn.ReLU(inplace=True)
|
63 |
+
elif act == "gelu":
|
64 |
+
activation = nn.GELU()
|
65 |
+
else:
|
66 |
+
activation = _get_activation(self.encoder)
|
67 |
+
|
68 |
+
if self.encoder_reduction >= self.block_size: # 8, 16
|
69 |
+
self.refiner = ConvUpsample(
|
70 |
+
in_channels=self.encoder_channels,
|
71 |
+
out_channels=self.encoder_channels,
|
72 |
+
scale_factor=self.encoder_reduction // self.block_size,
|
73 |
+
norm_layer=norm_layer,
|
74 |
+
activation=activation,
|
75 |
+
)
|
76 |
+
else: # 32
|
77 |
+
self.refiner = ConvDownsample(
|
78 |
+
in_channels=self.encoder_channels,
|
79 |
+
out_channels=self.encoder_channels,
|
80 |
+
norm_layer=norm_layer,
|
81 |
+
activation=activation,
|
82 |
+
)
|
83 |
+
self.refiner_channels = self.encoder_channels
|
84 |
+
self.refiner_reduction = self.block_size
|
85 |
+
|
86 |
+
self.decoder = nn.Identity()
|
87 |
+
self.decoder_channels = self.encoder_channels
|
88 |
+
self.decoder_reduction = self.refiner_reduction
|
89 |
+
|
90 |
+
def encode(self, x: Tensor) -> Tensor:
|
91 |
+
return self.encoder(x)
|
92 |
+
|
93 |
+
def refine(self, x: Tensor) -> Tensor:
|
94 |
+
return self.refiner(x)
|
95 |
+
|
96 |
+
def decode(self, x: Tensor) -> Tensor:
|
97 |
+
return self.decoder(x)
|
98 |
+
|
99 |
+
def forward(self, x: Tensor) -> Tensor:
|
100 |
+
x = self.encode(x)
|
101 |
+
x = self.refine(x)
|
102 |
+
x = self.decode(x)
|
103 |
+
return x
|
104 |
+
|
105 |
+
|
106 |
+
class VGGEncoderDecoder(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
model_name: str,
|
110 |
+
block_size: Optional[int] = None,
|
111 |
+
norm: str = "none",
|
112 |
+
act: str = "none",
|
113 |
+
) -> None:
|
114 |
+
super().__init__()
|
115 |
+
assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
|
116 |
+
assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
|
117 |
+
self.model_name = model_name
|
118 |
+
|
119 |
+
if model_name == "vgg11":
|
120 |
+
encoder = vgg11()
|
121 |
+
elif model_name == "vgg11_bn":
|
122 |
+
encoder = vgg11_bn()
|
123 |
+
elif model_name == "vgg13":
|
124 |
+
encoder = vgg13()
|
125 |
+
elif model_name == "vgg13_bn":
|
126 |
+
encoder = vgg13_bn()
|
127 |
+
elif model_name == "vgg16":
|
128 |
+
encoder = vgg16()
|
129 |
+
elif model_name == "vgg16_bn":
|
130 |
+
encoder = vgg16_bn()
|
131 |
+
elif model_name == "vgg19":
|
132 |
+
encoder = vgg19()
|
133 |
+
else: # model_name == "vgg19_bn"
|
134 |
+
encoder = vgg19_bn()
|
135 |
+
|
136 |
+
encoder_channels = 512
|
137 |
+
encoder_reduction = 16
|
138 |
+
decoder = make_vgg_layers(decoder_cfg, in_channels=encoder_channels, batch_norm="bn" in model_name, dilation=1)
|
139 |
+
decoder.apply(_init_weights)
|
140 |
+
|
141 |
+
if norm == "bn":
|
142 |
+
norm_layer = nn.BatchNorm2d
|
143 |
+
elif norm == "ln":
|
144 |
+
norm_layer = nn.LayerNorm
|
145 |
+
else:
|
146 |
+
norm_layer = _get_norm_layer(encoder)
|
147 |
+
|
148 |
+
if act == "relu":
|
149 |
+
activation = nn.ReLU(inplace=True)
|
150 |
+
elif act == "gelu":
|
151 |
+
activation = nn.GELU()
|
152 |
+
else:
|
153 |
+
activation = _get_activation(encoder)
|
154 |
+
|
155 |
+
self.encoder = nn.Sequential(encoder, decoder)
|
156 |
+
self.encoder_channels = decoder_cfg[-1]
|
157 |
+
self.encoder_reduction = encoder_reduction
|
158 |
+
self.block_size = block_size if block_size is not None else self.encoder_reduction
|
159 |
+
|
160 |
+
if self.encoder_reduction >= self.block_size:
|
161 |
+
self.refiner = ConvUpsample(
|
162 |
+
in_channels=self.encoder_channels,
|
163 |
+
out_channels=self.encoder_channels,
|
164 |
+
scale_factor=self.encoder_reduction // self.block_size,
|
165 |
+
norm_layer=norm_layer,
|
166 |
+
activation=activation,
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
self.refiner = ConvDownsample(
|
170 |
+
in_channels=self.encoder_channels,
|
171 |
+
out_channels=self.encoder_channels,
|
172 |
+
norm_layer=norm_layer,
|
173 |
+
activation=activation,
|
174 |
+
)
|
175 |
+
self.refiner_channels = self.encoder_channels
|
176 |
+
self.refiner_reduction = self.block_size
|
177 |
+
|
178 |
+
self.decoder = nn.Identity()
|
179 |
+
self.decoder_channels = self.refiner_channels
|
180 |
+
self.decoder_reduction = self.refiner_reduction
|
181 |
+
|
182 |
+
def encode(self, x: Tensor) -> Tensor:
|
183 |
+
return self.encoder(x)
|
184 |
+
|
185 |
+
def refine(self, x: Tensor) -> Tensor:
|
186 |
+
return self.refiner(x)
|
187 |
+
|
188 |
+
def decode(self, x: Tensor) -> Tensor:
|
189 |
+
return self.decoder(x)
|
190 |
+
|
191 |
+
def forward(self, x: Tensor) -> Tensor:
|
192 |
+
x = self.encode(x)
|
193 |
+
x = self.refine(x)
|
194 |
+
x = self.decode(x)
|
195 |
+
return x
|
196 |
+
|
197 |
+
|
198 |
+
class VGG(nn.Module):
|
199 |
+
def __init__(
|
200 |
+
self,
|
201 |
+
features: nn.Module,
|
202 |
+
) -> None:
|
203 |
+
super().__init__()
|
204 |
+
self.features = features
|
205 |
+
|
206 |
+
def forward(self, x: Tensor) -> Tensor:
|
207 |
+
x = self.features(x)
|
208 |
+
return x
|
209 |
+
|
210 |
+
|
211 |
+
def vgg11() -> VGG:
|
212 |
+
model = VGG(make_vgg_layers(vgg_cfgs["A"]))
|
213 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg11"]), strict=False)
|
214 |
+
return model
|
215 |
+
|
216 |
+
def vgg11_bn() -> VGG:
|
217 |
+
model = VGG(make_vgg_layers(vgg_cfgs["A"], batch_norm=True))
|
218 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg11_bn"]), strict=False)
|
219 |
+
return model
|
220 |
+
|
221 |
+
def vgg13() -> VGG:
|
222 |
+
model = VGG(make_vgg_layers(vgg_cfgs["B"]))
|
223 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg13"]), strict=False)
|
224 |
+
return model
|
225 |
+
|
226 |
+
def vgg13_bn() -> VGG:
|
227 |
+
model = VGG(make_vgg_layers(vgg_cfgs["B"], batch_norm=True))
|
228 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg13_bn"]), strict=False)
|
229 |
+
return model
|
230 |
+
|
231 |
+
def vgg16() -> VGG:
|
232 |
+
model = VGG(make_vgg_layers(vgg_cfgs["D"]))
|
233 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg16"]), strict=False)
|
234 |
+
return model
|
235 |
+
|
236 |
+
def vgg16_bn() -> VGG:
|
237 |
+
model = VGG(make_vgg_layers(vgg_cfgs["D"], batch_norm=True))
|
238 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg16_bn"]), strict=False)
|
239 |
+
return model
|
240 |
+
|
241 |
+
def vgg19() -> VGG:
|
242 |
+
model = VGG(make_vgg_layers(vgg_cfgs["E"]))
|
243 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg19"]), strict=False)
|
244 |
+
return model
|
245 |
+
|
246 |
+
def vgg19_bn() -> VGG:
|
247 |
+
model = VGG(make_vgg_layers(vgg_cfgs["E"], batch_norm=True))
|
248 |
+
model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg19_bn"]), strict=False)
|
249 |
+
return model
|
250 |
+
|
251 |
+
def _vgg_encoder(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> VGGEncoder:
|
252 |
+
return VGGEncoder(model_name, block_size, norm=norm, act=act)
|
253 |
+
|
254 |
+
def _vgg_encoder_decoder(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> VGGEncoderDecoder:
|
255 |
+
return VGGEncoderDecoder(model_name, block_size, norm=norm, act=act)
|
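A rough comparison of the two VGG wrappers above. The expected shapes assume the blocks in models/utils behave as their names suggest (ConvUpsample scales the resolution by the given factor); the torchvision VGG weights are fetched from the URLs in models/ebc/utils.py on first call.
import torch
from models.ebc.vgg import _vgg_encoder, _vgg_encoder_decoder

enc = _vgg_encoder("vgg16_bn", block_size=8)      # refiner upsamples 1/16 features to 1/8
enc_dec = _vgg_encoder_decoder("vgg16_bn")        # extra decoder convs bring channels down to 128
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    print(enc(x).shape, enc_dec(x).shape)         # expected: (1, 512, 32, 32) and (1, 128, 16, 16)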
models/ebc/vit.py
ADDED
@@ -0,0 +1,323 @@
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
import timm
|
4 |
+
from einops import rearrange
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import math
|
8 |
+
from typing import Optional, Tuple
|
9 |
+
from ..utils import ConvUpsample, ConvDownsample, _get_activation, _get_norm_layer, ConvRefine
|
10 |
+
|
11 |
+
|
12 |
+
supported_vit_backbones = [
|
13 |
+
# Tiny
|
14 |
+
"vit_tiny_patch16_224", "vit_tiny_patch16_384",
|
15 |
+
# Small
|
16 |
+
"vit_small_patch8_224",
|
17 |
+
"vit_small_patch16_224", "vit_small_patch16_384",
|
18 |
+
"vit_small_patch32_224", "vit_small_patch32_384",
|
19 |
+
# Base
|
20 |
+
"vit_base_patch8_224",
|
21 |
+
"vit_base_patch16_224", "vit_base_patch16_384",
|
22 |
+
"vit_base_patch32_224", "vit_base_patch32_384",
|
23 |
+
# Large
|
24 |
+
"vit_large_patch16_224", "vit_large_patch16_384",
|
25 |
+
"vit_large_patch32_224", "vit_large_patch32_384",
|
26 |
+
# Huge
|
27 |
+
"vit_huge_patch14_224",
|
28 |
+
]
|
29 |
+
|
30 |
+
|
31 |
+
refiner_channels = {
|
32 |
+
"vit_tiny_patch16_224": 192,
|
33 |
+
"vit_tiny_patch16_384": 192,
|
34 |
+
"vit_small_patch8_224": 384,
|
35 |
+
"vit_small_patch16_224": 384,
|
36 |
+
"vit_small_patch16_384": 384,
|
37 |
+
"vit_small_patch32_224": 384,
|
38 |
+
"vit_small_patch32_384": 384,
|
39 |
+
"vit_base_patch8_224": 768,
|
40 |
+
"vit_base_patch16_224": 768,
|
41 |
+
"vit_base_patch16_384": 768,
|
42 |
+
"vit_base_patch32_224": 768,
|
43 |
+
"vit_base_patch32_384": 768,
|
44 |
+
"vit_large_patch16_224": 1024,
|
45 |
+
"vit_large_patch16_384": 1024,
|
46 |
+
"vit_large_patch32_224": 1024,
|
47 |
+
"vit_large_patch32_384": 1024,
|
48 |
+
}
|
49 |
+
|
50 |
+
refiner_groups = {
|
51 |
+
"vit_tiny_patch16_224": 1,
|
52 |
+
"vit_tiny_patch16_384": 1,
|
53 |
+
"vit_small_patch8_224": 1,
|
54 |
+
"vit_small_patch16_224": 1,
|
55 |
+
"vit_small_patch16_384": 1,
|
56 |
+
"vit_small_patch32_224": 1,
|
57 |
+
"vit_small_patch32_384": 1,
|
58 |
+
"vit_base_patch8_224": 1,
|
59 |
+
"vit_base_patch16_224": 1,
|
60 |
+
"vit_base_patch16_384": 1,
|
61 |
+
"vit_base_patch32_224": 1,
|
62 |
+
"vit_base_patch32_384": 1,
|
63 |
+
"vit_large_patch16_224": 1,
|
64 |
+
"vit_large_patch16_384": 1,
|
65 |
+
"vit_large_patch32_224": 1,
|
66 |
+
"vit_large_patch32_384": 1,
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
class ViT(nn.Module):
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
model_name: str,
|
74 |
+
block_size: Optional[int] = None,
|
75 |
+
num_vpt: int = 32,
|
76 |
+
vpt_drop: float = 0.0,
|
77 |
+
input_size: Optional[Tuple[int, int]] = None,
|
78 |
+
norm: str = "none",
|
79 |
+
act: str = "none"
|
80 |
+
) -> None:
|
81 |
+
super().__init__()
|
82 |
+
assert model_name in supported_vit_backbones, f"Model {model_name} not supported"
|
83 |
+
assert num_vpt >= 0, f"Number of VPT tokens should be non-negative, but got {num_vpt}."
|
84 |
+
self.model_name = model_name
|
85 |
+
|
86 |
+
self.num_vpt = num_vpt
|
87 |
+
self.vpt_drop = vpt_drop
|
88 |
+
|
89 |
+
model = timm.create_model(model_name, pretrained=True)
|
90 |
+
|
91 |
+
self.input_size = input_size if input_size is not None else model.patch_embed.img_size
|
92 |
+
self.pretrain_size = model.patch_embed.img_size
|
93 |
+
self.patch_size = model.patch_embed.patch_size
|
94 |
+
|
95 |
+
if self.patch_size[0] in [8, 16, 32]:
|
96 |
+
assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
|
97 |
+
else: # patch_size == 14
|
98 |
+
assert block_size is None or block_size in [7, 14, 28], f"Block size should be one of [7, 14, 28], but got {block_size}."
|
99 |
+
|
100 |
+
self.num_layers = len(model.blocks)
|
101 |
+
self.embed_dim = model.cls_token.shape[-1]
|
102 |
+
|
103 |
+
if self.num_vpt > 0:  # Visual prompt tuning is used, so freeze the backbone
|
104 |
+
for param in model.parameters():
|
105 |
+
param.requires_grad = False
|
106 |
+
|
107 |
+
# Setup VPT tokens
|
108 |
+
val = math.sqrt(6. / float(3 * self.patch_size[0] + self.embed_dim))
|
109 |
+
for idx in range(self.num_layers):
|
110 |
+
setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.embed_dim)))
|
111 |
+
nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
|
112 |
+
setattr(self, f"vpt_drop_{idx}", nn.Dropout(self.vpt_drop))
|
113 |
+
|
114 |
+
self.patch_embed = model.patch_embed
|
115 |
+
self.cls_token = model.cls_token
|
116 |
+
self.pos_embed = model.pos_embed
|
117 |
+
self.pos_drop = model.pos_drop
|
118 |
+
self.patch_drop = model.patch_drop
|
119 |
+
self.norm_pre = model.norm_pre
|
120 |
+
|
121 |
+
self.blocks = model.blocks
|
122 |
+
self.norm = model.norm
|
123 |
+
|
124 |
+
self.encoder_channels = self.embed_dim
|
125 |
+
self.encoder_reduction = self.patch_size[0]
|
126 |
+
self.block_size = block_size if block_size is not None else self.encoder_reduction
|
127 |
+
|
128 |
+
if norm == "bn":
|
129 |
+
norm_layer = nn.BatchNorm2d
|
130 |
+
elif norm == "ln":
|
131 |
+
norm_layer = nn.LayerNorm
|
132 |
+
else:
|
133 |
+
norm_layer = _get_norm_layer(model)
|
134 |
+
|
135 |
+
if act == "relu":
|
136 |
+
activation = nn.ReLU(inplace=True)
|
137 |
+
elif act == "gelu":
|
138 |
+
activation = nn.GELU()
|
139 |
+
else:
|
140 |
+
activation = _get_activation(model)
|
141 |
+
|
142 |
+
if self.block_size < self.encoder_reduction:
|
143 |
+
assert self.block_size == self.encoder_reduction // 2, f"Block size should be half of the encoder reduction, but got {self.block_size} and {self.encoder_reduction}."
|
144 |
+
self.refiner = ConvUpsample(
|
145 |
+
in_channels=self.encoder_channels,
|
146 |
+
out_channels=self.encoder_channels,
|
147 |
+
norm_layer=norm_layer,
|
148 |
+
activation=activation,
|
149 |
+
)
|
150 |
+
elif self.block_size > self.encoder_reduction:
|
151 |
+
assert self.block_size == self.encoder_reduction * 2, f"Block size should be double of the encoder reduction, but got {self.block_size} and {self.encoder_reduction}."
|
152 |
+
self.refiner = ConvDownsample(
|
153 |
+
in_channels=self.encoder_channels,
|
154 |
+
out_channels=self.encoder_channels,
|
155 |
+
norm_layer=norm_layer,
|
156 |
+
activation=activation,
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
self.refiner = ConvRefine(
|
160 |
+
in_channels=self.encoder_channels,
|
161 |
+
out_channels=self.encoder_channels,
|
162 |
+
norm_layer=norm_layer,
|
163 |
+
activation=activation,
|
164 |
+
)
|
165 |
+
|
166 |
+
self.refiner_channels = self.encoder_channels
|
167 |
+
self.refiner_reduction = self.block_size
|
168 |
+
|
169 |
+
self.decoder = nn.Identity()
|
170 |
+
self.decoder_channels = self.refiner_channels
|
171 |
+
self.decoder_reduction = self.refiner_reduction
|
172 |
+
|
173 |
+
# Adjust the positional embedding to match the new input size
|
174 |
+
self._adjust_pos_embed()
|
175 |
+
|
176 |
+
def _adjust_pos_embed(self) -> Tensor:
|
177 |
+
"""
|
178 |
+
Adjust the positional embedding to match the spatial resolution of the feature map.
|
179 |
+
|
180 |
+
Note:
|
181 |
+
This method takes no arguments; the original resolution comes from self.pretrain_size
|
182 |
+
and the target resolution comes from self.input_size.
|
183 |
+
"""
|
184 |
+
self.pos_embed = nn.Parameter(self._interpolate_pos_embed(self.pretrain_size[0], self.pretrain_size[1], self.input_size[0], self.input_size[1]), requires_grad=self.num_vpt == 0)
|
185 |
+
|
186 |
+
def _interpolate_pos_embed(self, orig_h: int, orig_w: int, new_h: int, new_w: int) -> Tensor:
|
187 |
+
"""
|
188 |
+
Interpolate the positional embedding to match the spatial resolution of the feature map.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
orig_h, orig_w: The original spatial resolution of the image.
|
192 |
+
new_h, new_w: The new spatial resolution of the image.
|
193 |
+
"""
|
194 |
+
if (orig_h, orig_w) == (new_h, new_w):
|
195 |
+
return self.pos_embed # (1, (h * w + 1), d)
|
196 |
+
|
197 |
+
orig_h_patches, orig_w_patches = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
|
198 |
+
new_h_patches, new_w_patches = new_h // self.patch_size[0], new_w // self.patch_size[1]
|
199 |
+
class_pos_embed, patch_pos_embed = self.pos_embed[:, :1, :], self.pos_embed[:, 1:, :]
|
200 |
+
patch_pos_embed = rearrange(patch_pos_embed, "1 (h w) d -> 1 d h w", h=orig_h_patches, w=orig_w_patches)
|
201 |
+
patch_pos_embed = F.interpolate(patch_pos_embed, size=(new_h_patches, new_w_patches), mode="bicubic", antialias=True)
|
202 |
+
patch_pos_embed = rearrange(patch_pos_embed, "1 d h w -> 1 (h w) d")
|
203 |
+
pos_embed = torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
204 |
+
return pos_embed
|
205 |
+
|
206 |
+
def train(self, mode: bool = True):
|
207 |
+
if self.num_vpt > 0 and mode:
|
208 |
+
self.patch_embed.eval()
|
209 |
+
self.pos_drop.eval()
|
210 |
+
self.patch_drop.eval()
|
211 |
+
self.norm_pre.eval()
|
212 |
+
|
213 |
+
self.blocks.eval()
|
214 |
+
self.norm.eval()
|
215 |
+
|
216 |
+
for idx in range(self.num_layers):
|
217 |
+
getattr(self, f"vpt_drop_{idx}").train()
|
218 |
+
|
219 |
+
self.refiner.train()
|
220 |
+
self.decoder.train()
|
221 |
+
|
222 |
+
else:
|
223 |
+
for module in self.children():
|
224 |
+
module.train(mode)
|
225 |
+
|
226 |
+
def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
|
227 |
+
vpt = getattr(self, f"vpt_{layer}").unsqueeze(0).expand(batch_size, -1, -1).to(device) # (batch_size, num_vpt, embed_dim)
|
228 |
+
vpt = getattr(self, f"vpt_drop_{layer}")(vpt)
|
229 |
+
|
230 |
+
return vpt
|
231 |
+
|
232 |
+
def _forward_patch_embed(self, x: Tensor) -> Tensor:
|
233 |
+
# This step 1) embeds x into patches, 2) appends the class token, and 3) adds positional embeddings.
|
234 |
+
assert len(x.shape) == 4, f"Expected input to have shape (batch_size, 3, height, width), but got {x.shape}"
|
235 |
+
batch_size, _, height, width = x.shape
|
236 |
+
|
237 |
+
# Step 1: Embed x into patches
|
238 |
+
x = self.patch_embed(x) # (b, h * w, d)
|
239 |
+
|
240 |
+
# Step 2: Append the class token
|
241 |
+
cls_token = self.cls_token.expand(batch_size, 1, -1)
|
242 |
+
x = torch.cat([cls_token, x], dim=1)
|
243 |
+
|
244 |
+
# Step 3: Add positional embeddings
|
245 |
+
pos_embed = self._interpolate_pos_embed(orig_h=self.input_size[0], orig_w=self.input_size[1], new_h=height, new_w=width).expand(batch_size, -1, -1)
|
246 |
+
x = self.pos_drop(x + pos_embed)
|
247 |
+
return x
|
248 |
+
|
249 |
+
def _forward_vpt(self, x: Tensor, idx: int) -> Tensor:
|
250 |
+
batch_size = x.shape[0]
|
251 |
+
device = x.device
|
252 |
+
|
253 |
+
# Assemble
|
254 |
+
vpt = self._prepare_vpt(idx, batch_size, device)
|
255 |
+
x = torch.cat([
|
256 |
+
x[:, :1, :], # class token
|
257 |
+
vpt,
|
258 |
+
x[:, 1:, :] # patches
|
259 |
+
], dim=1)
|
260 |
+
|
261 |
+
# Forward
|
262 |
+
x = self.blocks[idx](x)
|
263 |
+
|
264 |
+
# Disassemble
|
265 |
+
x = torch.cat([
|
266 |
+
x[:, :1, :], # class token
|
267 |
+
x[:, 1 + self.num_vpt:, :] # patches
|
268 |
+
], dim=1)
|
269 |
+
|
270 |
+
return x
|
271 |
+
|
272 |
+
def _forward(self, x: Tensor, idx: int) -> Tensor:
|
273 |
+
x = self.blocks[idx](x)
|
274 |
+
return x
|
275 |
+
|
276 |
+
def encode(self, x: Tensor) -> Tensor:
|
277 |
+
orig_h, orig_w = x.shape[-2:]
|
278 |
+
num_patches_h, num_patches_w = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
|
279 |
+
|
280 |
+
x = self._forward_patch_embed(x)
|
281 |
+
x = self.patch_drop(x)
|
282 |
+
x = self.norm_pre(x)
|
283 |
+
|
284 |
+
for idx in range(self.num_layers):
|
285 |
+
x = self._forward_vpt(x, idx) if self.num_vpt > 0 else self._forward(x, idx)
|
286 |
+
|
287 |
+
x = self.norm(x)
|
288 |
+
x = x[:, 1:, :]
|
289 |
+
x = rearrange(x, "b (h w) d -> b d h w", h=num_patches_h, w=num_patches_w)
|
290 |
+
return x
|
291 |
+
|
292 |
+
def refine(self, x: Tensor) -> Tensor:
|
293 |
+
return self.refiner(x)
|
294 |
+
|
295 |
+
def decode(self, x: Tensor) -> Tensor:
|
296 |
+
return self.decoder(x)
|
297 |
+
|
298 |
+
def forward(self, x: Tensor) -> Tensor:
|
299 |
+
x = self.encode(x)
|
300 |
+
x = self.refine(x)
|
301 |
+
x = self.decode(x)
|
302 |
+
return x
|
303 |
+
|
304 |
+
|
305 |
+
def _vit(
|
306 |
+
model_name: str,
|
307 |
+
block_size: Optional[int] = None,
|
308 |
+
num_vpt: int = 32,
|
309 |
+
vpt_drop: float = 0.0,
|
310 |
+
input_size: Optional[Tuple[int, int]] = None,
|
311 |
+
norm: str = "none",
|
312 |
+
act: str = "none"
|
313 |
+
) -> ViT:
|
314 |
+
model = ViT(
|
315 |
+
model_name=model_name,
|
316 |
+
block_size=block_size,
|
317 |
+
num_vpt=num_vpt,
|
318 |
+
vpt_drop=vpt_drop,
|
319 |
+
input_size=input_size,
|
320 |
+
norm=norm,
|
321 |
+
act=act
|
322 |
+
)
|
323 |
+
return model
|
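A sketch of the visual-prompt-tuning path in the ViT wrapper above; the model name, num_vpt and input_size are arbitrary choices, and the timm weights must be downloadable.
import torch
from models.ebc.vit import _vit

backbone = _vit("vit_base_patch16_224", block_size=16, num_vpt=8, vpt_drop=0.1, input_size=(224, 224))
trainable = [n for n, p in backbone.named_parameters() if p.requires_grad]
print(trainable)                                   # only the vpt_* tokens plus the refiner remain trainable
with torch.no_grad():
    y = backbone(torch.randn(1, 3, 224, 224))      # (1, 768, 14, 14) for this backbone
print(y.shape)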
models/utils/__init__.py
ADDED
@@ -0,0 +1,56 @@
1 |
+
from torch import nn
|
2 |
+
from typing import Optional
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
from .utils import _init_weights, interpolate_pos_embed
|
6 |
+
from .blocks import DepthSeparableConv2d, conv1x1, conv3x3, Conv2dLayerNorm
|
7 |
+
from .refine import ConvRefine, LightConvRefine, LighterConvRefine
|
8 |
+
from .downsample import ConvDownsample, LightConvDownsample, LighterConvDownsample
|
9 |
+
from .upsample import ConvUpsample, LightConvUpsample, LighterConvUpsample
|
10 |
+
from .multi_scale import MultiScale
|
11 |
+
from .blocks import ConvAdapter, ViTAdapter
|
12 |
+
|
13 |
+
|
14 |
+
def _get_norm_layer(model: nn.Module) -> Optional[nn.Module]:
|
15 |
+
for module in model.modules():
|
16 |
+
if isinstance(module, nn.BatchNorm2d):
|
17 |
+
return nn.BatchNorm2d
|
18 |
+
elif isinstance(module, nn.GroupNorm):
|
19 |
+
num_groups = module.num_groups
|
20 |
+
return partial(nn.GroupNorm, num_groups=num_groups)
|
21 |
+
elif isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
|
22 |
+
return Conv2dLayerNorm
|
23 |
+
return None
|
24 |
+
|
25 |
+
|
26 |
+
def _get_activation(model: nn.Module) -> Optional[nn.Module]:
|
27 |
+
for module in model.modules():
|
28 |
+
if isinstance(module, nn.BatchNorm2d):
|
29 |
+
return nn.ReLU(inplace=True)
|
30 |
+
elif isinstance(module, nn.GroupNorm):
|
31 |
+
return nn.ReLU(inplace=True)
|
32 |
+
elif isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
|
33 |
+
return nn.GELU()
|
34 |
+
return nn.GELU()
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
__all__ = [
|
39 |
+
"_init_weights", "_check_norm_layer", "_check_activation",
|
40 |
+
"conv1x1",
|
41 |
+
"conv3x3",
|
42 |
+
"Conv2dLayerNorm",
|
43 |
+
"interpolate_pos_embed",
|
44 |
+
"DepthSeparableConv2d",
|
45 |
+
"ConvRefine",
|
46 |
+
"LightConvRefine",
|
47 |
+
"LighterConvRefine",
|
48 |
+
"ConvDownsample",
|
49 |
+
"LightConvDownsample",
|
50 |
+
"LighterConvDownsample",
|
51 |
+
"ConvUpsample",
|
52 |
+
"LightConvUpsample",
|
53 |
+
"LighterConvUpsample",
|
54 |
+
"MultiScale",
|
55 |
+
"ConvAdapter", "ViTAdapter",
|
56 |
+
]
|
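A quick check of the two helpers defined above; the toy module here is unrelated to the repo and only shows what gets inferred from a BatchNorm-based network.
from torch import nn
from models.utils import _get_norm_layer, _get_activation

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(inplace=True))
print(_get_norm_layer(toy))    # nn.BatchNorm2d, returned as a class to be instantiated later
print(_get_activation(toy))    # ReLU(inplace=True), returned as an instance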
models/utils/blocks.py
ADDED
@@ -0,0 +1,617 @@
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
from einops.layers.torch import Rearrange
|
6 |
+
|
7 |
+
from typing import Callable, Optional, Sequence, Tuple, Union, List
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
from .utils import _init_weights, _make_ntuple, _log_api_usage_once
|
11 |
+
|
12 |
+
|
13 |
+
def conv3x3(
|
14 |
+
in_channels: int,
|
15 |
+
out_channels: int,
|
16 |
+
stride: int = 1,
|
17 |
+
groups: int = 1,
|
18 |
+
dilation: int = 1,
|
19 |
+
bias: bool = True,
|
20 |
+
) -> nn.Conv2d:
|
21 |
+
"""3x3 convolution with padding"""
|
22 |
+
conv = nn.Conv2d(
|
23 |
+
in_channels,
|
24 |
+
out_channels,
|
25 |
+
kernel_size=3,
|
26 |
+
stride=stride,
|
27 |
+
padding=dilation,
|
28 |
+
groups=groups,
|
29 |
+
bias=bias,
|
30 |
+
dilation=dilation,
|
31 |
+
)
|
32 |
+
conv.apply(_init_weights)
|
33 |
+
return conv
|
34 |
+
|
35 |
+
|
36 |
+
def conv1x1(
|
37 |
+
in_channels: int,
|
38 |
+
out_channels: int,
|
39 |
+
stride: int = 1,
|
40 |
+
bias: bool = True,
|
41 |
+
) -> nn.Conv2d:
|
42 |
+
"""1x1 convolution"""
|
43 |
+
conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=bias)
|
44 |
+
conv.apply(_init_weights)
|
45 |
+
return conv
|
46 |
+
|
47 |
+
|
48 |
+
class DepthSeparableConv2d(nn.Module):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
in_channels: int,
|
52 |
+
out_channels: int,
|
53 |
+
kernel_size: int,
|
54 |
+
stride: int = 1,
|
55 |
+
padding: int = 0,
|
56 |
+
dilation: int = 1,
|
57 |
+
bias: bool = True,
|
58 |
+
padding_mode: str = "zeros",
|
59 |
+
) -> None:
|
60 |
+
super().__init__()
|
61 |
+
# Depthwise convolution: one filter per input channel.
|
62 |
+
self.depthwise = nn.Conv2d(
|
63 |
+
in_channels=in_channels,
|
64 |
+
out_channels=in_channels,
|
65 |
+
kernel_size=kernel_size,
|
66 |
+
stride=stride,
|
67 |
+
padding=padding,
|
68 |
+
dilation=dilation,
|
69 |
+
groups=in_channels,
|
70 |
+
bias=bias,
|
71 |
+
padding_mode=padding_mode
|
72 |
+
)
|
73 |
+
# Pointwise convolution: combine the features across channels.
|
74 |
+
self.pointwise = nn.Conv2d(
|
75 |
+
in_channels=in_channels,
|
76 |
+
out_channels=out_channels,
|
77 |
+
kernel_size=1,
|
78 |
+
stride=1,
|
79 |
+
padding=0,
|
80 |
+
dilation=1,
|
81 |
+
groups=1,
|
82 |
+
bias=bias,
|
83 |
+
padding_mode=padding_mode
|
84 |
+
)
|
85 |
+
self.apply(_init_weights)
|
86 |
+
|
87 |
+
def forward(self, x: Tensor) -> Tensor:
|
88 |
+
return self.pointwise(self.depthwise(x))
|
89 |
+
|
90 |
+
|
91 |
+
class SEBlock(nn.Module):
|
92 |
+
def __init__(self, channels: int, reduction: int = 16):
|
93 |
+
super().__init__()
|
94 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
95 |
+
self.fc = nn.Sequential(
|
96 |
+
nn.Linear(channels, channels // reduction, bias=False),
|
97 |
+
nn.ReLU(inplace=True),
|
98 |
+
nn.Linear(channels // reduction, channels, bias=False),
|
99 |
+
nn.Sigmoid()
|
100 |
+
)
|
101 |
+
self.apply(_init_weights)
|
102 |
+
|
103 |
+
def forward(self, x: Tensor) -> Tensor:
|
104 |
+
B, C, _, _ = x.shape
|
105 |
+
y = self.avg_pool(x).view(B, C)
|
106 |
+
y = self.fc(y).view(B, C, 1, 1)
|
107 |
+
return x * y
|
108 |
+
|
109 |
+
|
110 |
+
class BasicBlock(nn.Module):
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
in_channels: int,
|
114 |
+
out_channels: int,
|
115 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
116 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
117 |
+
groups: int = 1,
|
118 |
+
) -> None:
|
119 |
+
super().__init__()
|
120 |
+
assert isinstance(groups, int) and groups > 0, f"Expected groups to be a positive integer, but got {groups}"
|
121 |
+
assert in_channels % groups == 0, f"Expected in_channels to be divisible by groups, but got {in_channels} % {groups}"
|
122 |
+
assert out_channels % groups == 0, f"Expected out_channels to be divisible by groups, but got {out_channels} % {groups}"
|
123 |
+
self.grouped_conv = groups > 1
|
124 |
+
self.conv1 = conv3x3(
|
125 |
+
in_channels=in_channels,
|
126 |
+
out_channels=out_channels,
|
127 |
+
stride=1,
|
128 |
+
bias=not norm_layer,
|
129 |
+
groups=groups,
|
130 |
+
)
|
131 |
+
if self.grouped_conv:
|
132 |
+
self.conv1_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=not norm_layer)
|
133 |
+
|
134 |
+
self.norm1 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
135 |
+
self.act1 = activation
|
136 |
+
|
137 |
+
self.conv2 = conv3x3(
|
138 |
+
in_channels=out_channels,
|
139 |
+
out_channels=out_channels,
|
140 |
+
stride=1,
|
141 |
+
bias=not norm_layer,
|
142 |
+
groups=groups,
|
143 |
+
)
|
144 |
+
if self.grouped_conv:
|
145 |
+
self.conv2_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=not norm_layer)
|
146 |
+
|
147 |
+
self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
148 |
+
self.act2 = activation
|
149 |
+
|
150 |
+
if in_channels != out_channels:
|
151 |
+
self.downsample = nn.Sequential(
|
152 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
153 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
self.downsample = nn.Identity()
|
157 |
+
|
158 |
+
self.apply(_init_weights)
|
159 |
+
|
160 |
+
def forward(self, x: Tensor) -> Tensor:
|
161 |
+
identity = x
|
162 |
+
|
163 |
+
out = self.conv1(x)
|
164 |
+
out = self.conv1_1x1(out) if self.grouped_conv else out
|
165 |
+
out = self.norm1(out)
|
166 |
+
out = self.act1(out)
|
167 |
+
|
168 |
+
out = self.conv2(out)
|
169 |
+
out = self.conv2_1x1(out) if self.grouped_conv else out
|
170 |
+
out = self.norm2(out)
|
171 |
+
|
172 |
+
out += self.downsample(identity)
|
173 |
+
out = self.act2(out)
|
174 |
+
|
175 |
+
return out
|
176 |
+
|
177 |
+
|
178 |
+
class LightBasicBlock(nn.Module):
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
in_channels: int,
|
182 |
+
out_channels: int,
|
183 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
184 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
185 |
+
) -> None:
|
186 |
+
super().__init__()
|
187 |
+
self.conv1 = DepthSeparableConv2d(
|
188 |
+
in_channels=in_channels,
|
189 |
+
out_channels=out_channels,
|
190 |
+
kernel_size=3,
|
191 |
+
stride=1,
|
192 |
+
padding=1,
|
193 |
+
bias=not norm_layer,
|
194 |
+
)
|
195 |
+
self.norm1 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
196 |
+
self.act1 = activation
|
197 |
+
|
198 |
+
self.conv2 = DepthSeparableConv2d(
|
199 |
+
in_channels=out_channels,
|
200 |
+
out_channels=out_channels,
|
201 |
+
kernel_size=3,
|
202 |
+
stride=1,
|
203 |
+
padding=1,
|
204 |
+
bias=not norm_layer,
|
205 |
+
)
|
206 |
+
self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
207 |
+
self.act2 = activation
|
208 |
+
|
209 |
+
if in_channels != out_channels:
|
210 |
+
self.downsample = nn.Sequential(
|
211 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
212 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
self.downsample = nn.Identity()
|
216 |
+
|
217 |
+
self.apply(_init_weights)
|
218 |
+
|
219 |
+
def forward(self, x: Tensor) -> Tensor:
|
220 |
+
identity = x
|
221 |
+
|
222 |
+
out = self.conv1(x)
|
223 |
+
out = self.norm1(out)
|
224 |
+
out = self.act1(out)
|
225 |
+
|
226 |
+
out = self.conv2(out)
|
227 |
+
out = self.norm2(out)
|
228 |
+
|
229 |
+
out += self.downsample(identity)
|
230 |
+
out = self.act2(out)
|
231 |
+
|
232 |
+
return out
|
233 |
+
|
234 |
+
|
235 |
+
class Bottleneck(nn.Module):
|
236 |
+
def __init__(
|
237 |
+
self,
|
238 |
+
in_channels: int,
|
239 |
+
out_channels: int,
|
240 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
241 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
242 |
+
groups: int = 1,
|
243 |
+
base_width: int = 64,
|
244 |
+
expansion: float = 2.0,
|
245 |
+
) -> None:
|
246 |
+
super().__init__()
|
247 |
+
assert isinstance(groups, int) and groups > 0, f"Expected groups to be a positive integer, but got {groups}"
|
248 |
+
assert expansion > 0, f"Expected expansion to be greater than 0, but got {expansion}"
|
249 |
+
assert base_width > 0, f"Expected base_width to be greater than 0, but got {base_width}"
|
250 |
+
bottleneck_channels = int(in_channels * (base_width / 64.0) * expansion)
|
251 |
+
assert bottleneck_channels % groups == 0, f"Expected bottleneck_channels to be divisible by groups, but got {bottleneck_channels} % {groups}"
|
252 |
+
self.grouped_conv = groups > 1
|
253 |
+
self.expansion, self.base_width = expansion, base_width
|
254 |
+
|
255 |
+
self.conv_in = conv1x1(in_channels, bottleneck_channels, stride=1, bias=not norm_layer)
|
256 |
+
self.norm_in = norm_layer(bottleneck_channels) if norm_layer else nn.Identity()
|
257 |
+
self.act_in = activation
|
258 |
+
|
259 |
+
self.se_in = SEBlock(bottleneck_channels) if bottleneck_channels > in_channels else nn.Identity()
|
260 |
+
|
261 |
+
self.conv_block_1 = nn.Sequential(
|
262 |
+
conv3x3(
|
263 |
+
in_channels=bottleneck_channels,
|
264 |
+
out_channels=bottleneck_channels,
|
265 |
+
stride=1,
|
266 |
+
groups=groups,
|
267 |
+
bias=not norm_layer
|
268 |
+
),
|
269 |
+
conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer) if groups > 1 else nn.Identity(),
|
270 |
+
norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
|
271 |
+
activation,
|
272 |
+
)
|
273 |
+
|
274 |
+
self.conv_block_2 = nn.Sequential(
|
275 |
+
conv3x3(
|
276 |
+
in_channels=bottleneck_channels,
|
277 |
+
out_channels=bottleneck_channels,
|
278 |
+
stride=1,
|
279 |
+
groups=groups,
|
280 |
+
bias=not norm_layer
|
281 |
+
),
|
282 |
+
conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer) if groups > 1 else nn.Identity(),
|
283 |
+
norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
|
284 |
+
activation,
|
285 |
+
)
|
286 |
+
|
287 |
+
self.conv_out = conv1x1(bottleneck_channels, out_channels, stride=1, bias=not norm_layer)
|
288 |
+
self.norm_out = norm_layer(out_channels) if norm_layer else nn.Identity()
|
289 |
+
self.act_out = activation
|
290 |
+
self.se_out = SEBlock(out_channels) if out_channels > bottleneck_channels else nn.Identity()
|
291 |
+
|
292 |
+
if in_channels != out_channels:
|
293 |
+
self.downsample = nn.Sequential(
|
294 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
295 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
self.downsample = nn.Identity()
|
299 |
+
|
300 |
+
self.apply(_init_weights)
|
301 |
+
|
302 |
+
def forward(self, x: Tensor) -> Tensor:
|
303 |
+
identity = x
|
304 |
+
|
305 |
+
# expand
|
306 |
+
out = self.conv_in(x)
|
307 |
+
out = self.norm_in(out)
|
308 |
+
out = self.act_in(out)
|
309 |
+
out = self.se_in(out)
|
310 |
+
|
311 |
+
# conv
|
312 |
+
out = self.conv_block_1(out)
|
313 |
+
out = self.conv_block_2(out)
|
314 |
+
|
315 |
+
# reduce
|
316 |
+
out = self.conv_out(out)
|
317 |
+
out = self.norm_out(out)
|
318 |
+
out = self.se_out(out)
|
319 |
+
|
320 |
+
out += self.downsample(identity)
|
321 |
+
out = self.act_out(out)
|
322 |
+
return out
|
323 |
+
|
324 |
+
|
325 |
+
class ConvASPP(nn.Module):
|
326 |
+
def __init__(
|
327 |
+
self,
|
328 |
+
in_channels: int,
|
329 |
+
out_channels: int,
|
330 |
+
dilations: List[int] = [1, 2, 4],
|
331 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
332 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
333 |
+
groups: int = 1,
|
334 |
+
base_width: int = 64,
|
335 |
+
expansion: float = 2.0,
|
336 |
+
) -> None:
|
337 |
+
super().__init__()
|
338 |
+
assert isinstance(groups, int) and groups > 0, f"Expected groups to be a positive integer, but got {groups}"
|
339 |
+
assert expansion > 0, f"Expected expansion to be greater than 0, but got {expansion}"
|
340 |
+
assert base_width > 0, f"Expected base_width to be greater than 0, but got {base_width}"
|
341 |
+
bottleneck_channels = int(in_channels * (base_width / 64.0) * expansion)
|
342 |
+
assert bottleneck_channels % groups == 0, f"Expected bottleneck_channels to be divisible by groups, but got {bottleneck_channels} % {groups}"
|
343 |
+
self.expansion, self.base_width = expansion, base_width
|
344 |
+
|
345 |
+
self.conv_in = conv1x1(in_channels, bottleneck_channels, stride=1, bias=not norm_layer)
|
346 |
+
self.norm_in = norm_layer(bottleneck_channels) if norm_layer else nn.Identity()
|
347 |
+
self.act_in = activation
|
348 |
+
|
349 |
+
conv_blocks = [nn.Sequential(
|
350 |
+
conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer),
|
351 |
+
norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
|
352 |
+
activation
|
353 |
+
)]
|
354 |
+
|
355 |
+
for dilation in dilations:
|
356 |
+
conv_blocks.append(nn.Sequential(
|
357 |
+
conv3x3(
|
358 |
+
in_channels=bottleneck_channels,
|
359 |
+
out_channels=bottleneck_channels,
|
360 |
+
stride=1,
|
361 |
+
groups=groups,
|
362 |
+
dilation=dilation,
|
363 |
+
bias=not norm_layer
|
364 |
+
),
|
365 |
+
conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer) if groups > 1 else nn.Identity(),
|
366 |
+
norm_layer(bottleneck_channels) if norm_layer else nn.Identity(),
|
367 |
+
activation
|
368 |
+
))
|
369 |
+
|
370 |
+
self.convs = nn.ModuleList(conv_blocks)
|
371 |
+
|
372 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
373 |
+
self.conv_avg = conv1x1(bottleneck_channels, bottleneck_channels, stride=1, bias=not norm_layer)
|
374 |
+
self.norm_avg = norm_layer(bottleneck_channels) if norm_layer else nn.Identity()
|
375 |
+
self.act_avg = activation
|
376 |
+
|
377 |
+
self.se = SEBlock(bottleneck_channels * (len(dilations) + 2))
|
378 |
+
|
379 |
+
self.conv_out = conv1x1(bottleneck_channels * (len(dilations) + 2), out_channels, stride=1, bias=not norm_layer)
|
380 |
+
self.norm_out = norm_layer(out_channels) if norm_layer else nn.Identity()
|
381 |
+
self.act_out = activation
|
382 |
+
|
383 |
+
if in_channels != out_channels:
|
384 |
+
self.downsample = nn.Sequential(
|
385 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
386 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
self.downsample = nn.Identity()
|
390 |
+
|
391 |
+
self.apply(_init_weights)
|
392 |
+
|
393 |
+
def forward(self, x: Tensor) -> Tensor:
|
394 |
+
height, width = x.shape[-2:]
|
395 |
+
identity = x
|
396 |
+
|
397 |
+
# expand
|
398 |
+
out = self.conv_in(x)
|
399 |
+
out = self.norm_in(out)
|
400 |
+
out = self.act_in(out)
|
401 |
+
|
402 |
+
outs = []
|
403 |
+
for conv in self.convs:
|
404 |
+
outs.append(conv(out))
|
405 |
+
|
406 |
+
avg = self.avgpool(out)
|
407 |
+
avg = self.conv_avg(avg)
|
408 |
+
avg = self.norm_avg(avg)
|
409 |
+
avg = self.act_avg(avg) # (B, C, 1, 1)
|
410 |
+
avg = avg.repeat(1, 1, height, width)
|
411 |
+
|
412 |
+
outs = torch.cat([*outs, avg], dim=1) # (B, C * (len(dilations) + 2), H, W)
|
413 |
+
outs = self.se(outs)
|
414 |
+
|
415 |
+
# reduce
|
416 |
+
outs = self.conv_out(outs)
|
417 |
+
outs = self.norm_out(outs)
|
418 |
+
|
419 |
+
outs += self.downsample(identity)
|
420 |
+
outs = self.act_out(outs)
|
421 |
+
return outs
|
422 |
+
|
423 |
+
|
424 |
+
class ViTBlock(nn.Module):
|
425 |
+
def __init__(
|
426 |
+
self,
|
427 |
+
embed_dim: int,
|
428 |
+
num_heads: int = 8,
|
429 |
+
dropout: float = 0.0,
|
430 |
+
mlp_ratio: float = 4.0,
|
431 |
+
) -> None:
|
432 |
+
super().__init__()
|
433 |
+
assert embed_dim % num_heads == 0, f"Embedding dimension {embed_dim} should be divisible by number of heads {num_heads}"
|
434 |
+
self.embed_dim, self.num_heads = embed_dim, num_heads
|
435 |
+
self.dropout, self.mlp_ratio = dropout, mlp_ratio
|
436 |
+
|
437 |
+
self.norm1 = nn.LayerNorm(embed_dim)
|
438 |
+
self.attn = nn.MultiheadAttention(
|
439 |
+
embed_dim=embed_dim,
|
440 |
+
num_heads=num_heads,
|
441 |
+
dropout=dropout,
|
442 |
+
batch_first=True
|
443 |
+
)
|
444 |
+
|
445 |
+
self.norm2 = nn.LayerNorm(embed_dim)
|
446 |
+
self.mlp = nn.Sequential(
|
447 |
+
nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
|
448 |
+
nn.GELU(),
|
449 |
+
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
|
450 |
+
nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
|
451 |
+
nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
452 |
+
)
|
453 |
+
self.apply(_init_weights)
|
454 |
+
|
455 |
+
def forward(self, x: Tensor) -> Tensor:
|
456 |
+
assert len(x.shape) == 3, f"Expected input to have shape (B, N, C), but got {x.shape}"
|
457 |
+
# nn.MultiheadAttention expects (query, key, value) and returns (output, weights)
attn_in = self.norm1(x)
x = x + self.attn(attn_in, attn_in, attn_in, need_weights=False)[0]
|
458 |
+
x = x + self.mlp(self.norm2(x))
|
459 |
+
return x
|
460 |
+
|
461 |
+
|
462 |
+
class Conv2dLayerNorm(nn.Sequential):
|
463 |
+
"""
|
464 |
+
Layer normalization applied in a convolutional fashion.
|
465 |
+
"""
|
466 |
+
def __init__(self, dim: int) -> None:
|
467 |
+
super().__init__(
|
468 |
+
Rearrange("B C H W -> B H W C"),
|
469 |
+
nn.LayerNorm(dim),
|
470 |
+
Rearrange("B H W C -> B C H W")
|
471 |
+
)
|
472 |
+
self.apply(_init_weights)
|
473 |
+
|
474 |
+
|
475 |
+
class CvTAttention(nn.Module):
|
476 |
+
def __init__(
|
477 |
+
self,
|
478 |
+
embed_dim: int,
|
479 |
+
num_heads: int = 8,
|
480 |
+
dropout: float = 0.0,
|
481 |
+
q_stride: int = 1, # controls downsampling rate
|
482 |
+
kv_stride: int = 1,
|
483 |
+
) -> None:
|
484 |
+
super().__init__()
|
485 |
+
assert embed_dim % num_heads == 0, f"Embedding dimension {embed_dim} should be divisible by number of heads {num_heads}"
|
486 |
+
self.embed_dim, self.num_heads, self.dim_head = embed_dim, num_heads, embed_dim // num_heads
|
487 |
+
self.scale = self.dim_head ** -0.5
|
488 |
+
self.q_stride, self.kv_stride = q_stride, kv_stride
|
489 |
+
|
490 |
+
self.attend = nn.Softmax(dim=-1)
|
491 |
+
self.dropout = nn.Dropout(dropout)
|
492 |
+
|
493 |
+
self.to_q = DepthSeparableConv2d(
|
494 |
+
in_channels=embed_dim,
|
495 |
+
out_channels=embed_dim,
|
496 |
+
kernel_size=3,
|
497 |
+
stride=q_stride,
|
498 |
+
padding=1,
|
499 |
+
bias=False
|
500 |
+
)
|
501 |
+
self.to_k = DepthSeparableConv2d(
|
502 |
+
in_channels=embed_dim,
|
503 |
+
out_channels=embed_dim,
|
504 |
+
kernel_size=3,
|
505 |
+
stride=kv_stride,
|
506 |
+
padding=1,
|
507 |
+
bias=False
|
508 |
+
)
|
509 |
+
self.to_v = DepthSeparableConv2d(
|
510 |
+
in_channels=embed_dim,
|
511 |
+
out_channels=embed_dim,
|
512 |
+
kernel_size=3,
|
513 |
+
stride=kv_stride,
|
514 |
+
padding=1,
|
515 |
+
bias=False
|
516 |
+
)
|
517 |
+
|
518 |
+
self.to_out = nn.Sequential(
|
519 |
+
conv1x1(embed_dim, embed_dim, stride=1),
|
520 |
+
nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
521 |
+
)
|
522 |
+
|
523 |
+
self.apply(_init_weights)
|
524 |
+
|
525 |
+
def forward(self, x: Tensor) -> Tensor:
|
526 |
+
assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
|
527 |
+
assert x.shape[1] == self.embed_dim, f"Expected input to have embedding dimension {self.embed_dim}, but got {x.shape[1]}"
|
528 |
+
|
529 |
+
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
530 |
+
B, _, H, W = q.shape
|
531 |
+
q, k, v = map(lambda t: rearrange(t, "B (num_heads head_dim) H W -> (B num_heads) (H W) head_dim", num_heads=self.num_heads), (q, k, v))
|
532 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
533 |
+
attn = self.attend(attn)
|
534 |
+
attn = self.dropout(attn)
|
535 |
+
|
536 |
+
out = attn @ v
|
537 |
+
out = rearrange(out, "(B num_heads) (H W) head_dim -> B (num_heads head_dim) H W", B=B, H=H, W=W, num_heads=self.num_heads)
|
538 |
+
out = self.to_out(out)
|
539 |
+
|
540 |
+
return out
|
541 |
+
|
542 |
+
|
543 |
+
class CvTBlock(nn.Module):
|
544 |
+
"""
|
545 |
+
Implement convolutional vision transformer block.
|
546 |
+
"""
|
547 |
+
def __init__(
|
548 |
+
self,
|
549 |
+
embed_dim: int,
|
550 |
+
num_heads: int = 8,
|
551 |
+
dropout: float = 0.0,
|
552 |
+
mlp_ratio: float = 4.0,
|
553 |
+
q_stride: int = 1,
|
554 |
+
kv_stride: int = 1,
|
555 |
+
) -> None:
|
556 |
+
super().__init__()
|
557 |
+
assert embed_dim % num_heads == 0, f"Embedding dimension {embed_dim} should be divisible by number of heads {num_heads}."
|
558 |
+
self.embed_dim, self.num_heads = embed_dim, num_heads
|
559 |
+
|
560 |
+
self.norm1 = Conv2dLayerNorm(embed_dim)
|
561 |
+
self.attn = CvTAttention(embed_dim, num_heads, dropout, q_stride, kv_stride)
|
562 |
+
|
563 |
+
self.pool = nn.AvgPool2d(kernel_size=q_stride, stride=q_stride) if q_stride > 1 else nn.Identity()
|
564 |
+
|
565 |
+
self.norm2 = Conv2dLayerNorm(embed_dim)
|
566 |
+
self.mlp = nn.Sequential(
|
567 |
+
nn.Conv2d(embed_dim, int(embed_dim * mlp_ratio), kernel_size=1),
|
568 |
+
nn.GELU(),
|
569 |
+
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
|
570 |
+
nn.Conv2d(int(embed_dim * mlp_ratio), embed_dim, kernel_size=1),
|
571 |
+
nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
572 |
+
)
|
573 |
+
|
574 |
+
def forward(self, x: Tensor) -> Tensor:
|
575 |
+
x = self.pool(x) + self.attn(self.norm1(x))
|
576 |
+
x = x + self.mlp(self.norm2(x))
|
577 |
+
return x
|
578 |
+
|
579 |
+
|
580 |
+
class ConvAdapter(nn.Module):
|
581 |
+
def __init__(
|
582 |
+
self,
|
583 |
+
in_channels: int,
|
584 |
+
bottleneck_channels: int = 16,
|
585 |
+
) -> None:
|
586 |
+
super().__init__()
|
587 |
+
assert in_channels > 0, f"Expected input_channels to be greater than 0, but got {in_channels}"
|
588 |
+
assert bottleneck_channels > 0, f"Expected bottleneck_channels to be greater than 0, but got {bottleneck_channels}"
|
589 |
+
|
590 |
+
self.adapter = nn.Sequential(
|
591 |
+
nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1),
|
592 |
+
nn.GELU(),
|
593 |
+
nn.Conv2d(bottleneck_channels, in_channels, kernel_size=1),
|
594 |
+
)
|
595 |
+
nn.init.zeros_(self.adapter[2].weight)
|
596 |
+
nn.init.zeros_(self.adapter[2].bias)
|
597 |
+
|
598 |
+
def forward(self, x: Tensor) -> Tensor:
|
599 |
+
assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
|
600 |
+
return x + self.adapter(x)
|
601 |
+
|
602 |
+
|
603 |
+
class ViTAdapter(nn.Module):
|
604 |
+
def __init__(self, input_dim, bottleneck_dim):
|
605 |
+
super().__init__()
|
606 |
+
self.adapter = nn.Sequential(
|
607 |
+
nn.Linear(input_dim, bottleneck_dim),
|
608 |
+
nn.GELU(),  # GELU is the activation commonly used in ViTs
|
609 |
+
nn.Linear(bottleneck_dim, input_dim)
|
610 |
+
)
|
611 |
+
nn.init.zeros_(self.adapter[2].weight)
|
612 |
+
nn.init.zeros_(self.adapter[2].bias)
|
613 |
+
|
614 |
+
def forward(self, x: Tensor) -> Tensor:
|
615 |
+
assert len(x.shape) == 3, f"Expected input to have shape (B, N, C), but got {x.shape}"
|
616 |
+
return x + self.adapter(x)
|
617 |
+
|
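A short sketch of how the zero-initialized adapters above act on frozen features; the shapes are arbitrary examples. Because the second projection starts at zero, both adapters are exact identity mappings before training.

import torch

conv_adapter = ConvAdapter(in_channels=256, bottleneck_channels=16)
vit_adapter = ViTAdapter(input_dim=768, bottleneck_dim=64)
feat = torch.randn(2, 256, 32, 32)      # (B, C, H, W) feature map
tokens = torch.randn(2, 197, 768)       # (B, N, C) token sequence
assert torch.allclose(conv_adapter(feat), feat)      # identity at initialization
assert torch.allclose(vit_adapter(tokens), tokens)   # identity at initialization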
models/utils/carafe.py
ADDED
@@ -0,0 +1,203 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def carafe_forward(
|
6 |
+
features: torch.Tensor,
|
7 |
+
masks: torch.Tensor,
|
8 |
+
kernel_size: int,
|
9 |
+
group_size: int,
|
10 |
+
scale_factor: int
|
11 |
+
) -> torch.Tensor:
|
12 |
+
"""
|
13 |
+
Pure-PyTorch implementation of the CARAFE upsampling operator.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
features (Tensor): Input feature map of shape (N, C, H, W).
|
17 |
+
masks (Tensor): Reassembly kernel weights of shape
|
18 |
+
(N, kernel_size*kernel_size*group_size, H_out, W_out),
|
19 |
+
where H_out = H*scale_factor and W_out = W*scale_factor.
|
20 |
+
kernel_size (int): The spatial size of the reassembly kernel.
|
21 |
+
group_size (int): The group size to divide channels. Must divide C.
|
22 |
+
scale_factor (int): The upsampling factor.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
Tensor: Upsampled feature map of shape (N, C, H*scale_factor, W*scale_factor).
|
26 |
+
"""
|
27 |
+
N, C, H, W = features.size()
|
28 |
+
out_H, out_W = H * scale_factor, W * scale_factor
|
29 |
+
num_channels = C // group_size # channels per group
|
30 |
+
|
31 |
+
# Reshape features to (N, group_size, num_channels, H, W)
|
32 |
+
features = features.view(N, group_size, num_channels, H, W)
|
33 |
+
# Merge batch and group dims for unfolding
|
34 |
+
features_reshaped = features.view(N * group_size, num_channels, H, W)
|
35 |
+
# Extract local patches; use padding so that output spatial dims match input
|
36 |
+
patches = F.unfold(features_reshaped, kernel_size=kernel_size,
|
37 |
+
padding=(kernel_size - 1) // 2)
|
38 |
+
# patches shape: (N*group_size, num_channels*kernel_size*kernel_size, H*W)
|
39 |
+
# Reshape to (N, group_size, num_channels, kernel_size*kernel_size, H, W)
|
40 |
+
patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H, W)
|
41 |
+
# Flatten spatial dimensions: now (N, group_size, num_channels, kernel_size*kernel_size, H*W)
|
42 |
+
patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H * W)
|
43 |
+
|
44 |
+
# For each output pixel location, determine the corresponding base input index.
|
45 |
+
# For an output coordinate (oh, ow), the corresponding input index is:
|
46 |
+
# h = oh // scale_factor, w = ow // scale_factor, linear index = h * W + w.
|
47 |
+
device = features.device
|
48 |
+
# Create coordinate indices for output
|
49 |
+
h_idx = torch.div(torch.arange(out_H, device=device), scale_factor, rounding_mode='floor') # (out_H,)
|
50 |
+
w_idx = torch.div(torch.arange(out_W, device=device), scale_factor, rounding_mode='floor') # (out_W,)
|
51 |
+
# Form a 2D grid of base indices (shape: out_H x out_W)
|
52 |
+
h_idx = h_idx.unsqueeze(1).expand(out_H, out_W) # (out_H, out_W)
|
53 |
+
w_idx = w_idx.unsqueeze(0).expand(out_H, out_W) # (out_H, out_W)
|
54 |
+
base_idx = (h_idx * W + w_idx).view(-1) # (out_H*out_W,)
|
55 |
+
|
56 |
+
# Expand base_idx so that it can index the last dimension of patches:
|
57 |
+
# Desired shape for gathering: (N, group_size, num_channels, kernel_size*kernel_size, out_H*out_W)
|
58 |
+
base_idx = base_idx.view(1, 1, 1, 1, -1).expand(N, group_size, num_channels, kernel_size * kernel_size, -1)
|
59 |
+
# Gather patches corresponding to each output location
|
60 |
+
gathered_patches = torch.gather(patches, -1, base_idx)
|
61 |
+
# Reshape gathered patches to (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
|
62 |
+
gathered_patches = gathered_patches.view(N, group_size, num_channels, kernel_size * kernel_size, out_H, out_W)
|
63 |
+
|
64 |
+
# Reshape masks to separate groups.
|
65 |
+
# Expected mask shape: (N, kernel_size*kernel_size*group_size, out_H, out_W)
|
66 |
+
# Reshape to: (N, group_size, kernel_size*kernel_size, out_H, out_W)
|
67 |
+
masks = masks.view(N, group_size, kernel_size * kernel_size, out_H, out_W)
|
68 |
+
# For multiplication, add a channel dimension so that masks shape becomes
|
69 |
+
# (N, group_size, 1, kernel_size*kernel_size, out_H, out_W)
|
70 |
+
masks = masks.unsqueeze(2)
|
71 |
+
# Expand masks to match gathered_patches: (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
|
72 |
+
masks = masks.expand(-1, -1, num_channels, -1, -1, -1)
|
73 |
+
|
74 |
+
# Multiply patches with masks and sum over the kernel dimension.
|
75 |
+
# This yields the reassembled features for each output location.
|
76 |
+
out = (gathered_patches * masks).sum(dim=3) # shape: (N, group_size, num_channels, out_H, out_W)
|
77 |
+
# Reshape back to (N, C, out_H, out_W)
|
78 |
+
out = out.view(N, C, out_H, out_W)
|
79 |
+
return out
|
80 |
+
|
81 |
+
|
82 |
+
class CARAFE(nn.Module):
|
83 |
+
"""
|
84 |
+
CARAFE: Content-Aware ReAssembly of Features
|
85 |
+
|
86 |
+
This PyTorch module implements the CARAFE upsampling operator in pure Python.
|
87 |
+
Given an input feature map and its corresponding reassembly masks, the module
|
88 |
+
reassembles features from local patches to produce a higher-resolution output.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
kernel_size (int): Reassembly kernel size.
|
92 |
+
group_size (int): Group size for channel grouping (must divide number of channels).
|
93 |
+
scale_factor (int): Upsample ratio.
|
94 |
+
"""
|
95 |
+
def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
|
96 |
+
super(CARAFE, self).__init__()
|
97 |
+
self.kernel_size = kernel_size
|
98 |
+
self.group_size = group_size
|
99 |
+
self.scale_factor = scale_factor
|
100 |
+
|
101 |
+
def forward(self, features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
|
102 |
+
return carafe_forward(features, masks, self.kernel_size, self.group_size, self.scale_factor)
|
103 |
+
|
104 |
+
|
105 |
+
class CARAFEPack(nn.Module):
|
106 |
+
"""
|
107 |
+
A unified package of the CARAFE upsampler that contains:
|
108 |
+
1) A channel compressor.
|
109 |
+
2) A content encoder that predicts reassembly masks.
|
110 |
+
3) The CARAFE operator.
|
111 |
+
|
112 |
+
This is modeled after the official CARAFE package.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
channels (int): Number of input feature channels.
|
116 |
+
scale_factor (int): Upsample ratio.
|
117 |
+
up_kernel (int): Kernel size for the CARAFE operator.
|
118 |
+
up_group (int): Group size for the CARAFE operator.
|
119 |
+
encoder_kernel (int): Kernel size of the content encoder.
|
120 |
+
encoder_dilation (int): Dilation rate for the content encoder.
|
121 |
+
compressed_channels (int): Output channels for the channel compressor.
|
122 |
+
"""
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
channels: int,
|
126 |
+
scale_factor: int,
|
127 |
+
up_kernel: int = 5,
|
128 |
+
up_group: int = 1,
|
129 |
+
encoder_kernel: int = 3,
|
130 |
+
encoder_dilation: int = 1,
|
131 |
+
compressed_channels: int = 64
|
132 |
+
):
|
133 |
+
super(CARAFEPack, self).__init__()
|
134 |
+
self.channels = channels
|
135 |
+
self.scale_factor = scale_factor
|
136 |
+
self.up_kernel = up_kernel
|
137 |
+
self.up_group = up_group
|
138 |
+
self.encoder_kernel = encoder_kernel
|
139 |
+
self.encoder_dilation = encoder_dilation
|
140 |
+
self.compressed_channels = compressed_channels
|
141 |
+
|
142 |
+
# Compress input channels.
|
143 |
+
self.channel_compressor = nn.Conv2d(channels, compressed_channels, kernel_size=1)
|
144 |
+
# Predict reassembly masks.
|
145 |
+
self.content_encoder = nn.Conv2d(
|
146 |
+
compressed_channels,
|
147 |
+
up_kernel * up_kernel * up_group * scale_factor * scale_factor,
|
148 |
+
kernel_size=encoder_kernel,
|
149 |
+
padding=int((encoder_kernel - 1) * encoder_dilation / 2),
|
150 |
+
dilation=encoder_dilation
|
151 |
+
)
|
152 |
+
# Initialize weights (using Xavier for conv layers).
|
153 |
+
nn.init.xavier_uniform_(self.channel_compressor.weight)
|
154 |
+
nn.init.xavier_uniform_(self.content_encoder.weight)
|
155 |
+
if self.channel_compressor.bias is not None:
|
156 |
+
nn.init.constant_(self.channel_compressor.bias, 0)
|
157 |
+
if self.content_encoder.bias is not None:
|
158 |
+
nn.init.constant_(self.content_encoder.bias, 0)
|
159 |
+
|
160 |
+
def kernel_normalizer(self, mask: torch.Tensor) -> torch.Tensor:
|
161 |
+
"""
|
162 |
+
Normalize and reshape the mask.
|
163 |
+
Applies pixel shuffle to upsample the predicted kernel weights and then
|
164 |
+
applies softmax normalization across the kernel dimension.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
mask (Tensor): Predicted mask of shape (N, out_channels, H, W).
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
Tensor: Normalized mask of shape (N, up_group * up_kernel^2, H*scale, W*scale).
|
171 |
+
"""
|
172 |
+
# Pixel shuffle to rearrange and upsample the mask.
|
173 |
+
mask = F.pixel_shuffle(mask, self.scale_factor)
|
174 |
+
N, mask_c, H, W = mask.size()
|
175 |
+
# Determine the number of channels per kernel
|
176 |
+
mask_channel = mask_c // (self.up_kernel ** 2)
|
177 |
+
mask = mask.view(N, mask_channel, self.up_kernel ** 2, H, W)
|
178 |
+
mask = F.softmax(mask, dim=2)
|
179 |
+
mask = mask.view(N, mask_channel * self.up_kernel ** 2, H, W).contiguous()
|
180 |
+
return mask
|
181 |
+
|
182 |
+
def feature_reassemble(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
183 |
+
return carafe_forward(x, mask, self.up_kernel, self.up_group, self.scale_factor)
|
184 |
+
|
185 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
186 |
+
compressed_x = self.channel_compressor(x)
|
187 |
+
mask = self.content_encoder(compressed_x)
|
188 |
+
mask = self.kernel_normalizer(mask)
|
189 |
+
out = self.feature_reassemble(x, mask)
|
190 |
+
return out
|
191 |
+
|
192 |
+
|
193 |
+
# === Example Usage ===
|
194 |
+
if __name__ == '__main__':
|
195 |
+
# Create dummy input: batch size 2, 64 channels, 32x32 spatial resolution.
|
196 |
+
x = torch.randn(2, 64, 32, 32).cuda() # assuming GPU available
|
197 |
+
# Define CARAFEPack with upsample ratio 2.
|
198 |
+
# For example, use kernel size 5, group size 1.
|
199 |
+
upsampler = CARAFEPack(channels=64, scale_factor=2, up_kernel=5, up_group=1).cuda()
|
200 |
+
# Get upsampled feature map.
|
201 |
+
out = upsampler(x)
|
202 |
+
print("Input shape: ", x.shape)
|
203 |
+
print("Output shape:", out.shape) # Expected shape: (2, 64, 64, 64)
|
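For the bare CARAFE operator (without CARAFEPack), the caller must supply masks of shape (N, kernel_size*kernel_size*group_size, H*scale, W*scale) that sum to one over the kernel dimension. The sketch below builds a uniform mask by hand purely to show that contract; in practice the masks come from a content encoder such as the one in CARAFEPack.

# Shape-only sketch on CPU with a hand-built uniform mask (uses torch and CARAFE from above).
N, C, H, W = 1, 8, 16, 16
k, g, s = 5, 1, 2
feats = torch.randn(N, C, H, W)
masks = torch.full((N, k * k * g, H * s, W * s), 1.0 / (k * k))  # sums to 1 over the k*k kernel
up = CARAFE(kernel_size=k, group_size=g, scale_factor=s)
print(up(feats, masks).shape)  # torch.Size([1, 8, 32, 32])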
models/utils/downsample.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
from torch import nn, Tensor
|
2 |
+
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from .blocks import DepthSeparableConv2d, conv1x1, conv3x3
|
6 |
+
from .utils import _init_weights
|
7 |
+
|
8 |
+
|
9 |
+
class ConvDownsample(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
in_channels: int,
|
13 |
+
out_channels: int,
|
14 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
15 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
16 |
+
groups: int = 1,
|
17 |
+
) -> None:
|
18 |
+
super().__init__()
|
19 |
+
assert isinstance(groups, int) and groups > 0, f"Number of groups should be an integer greater than 0, but got {groups}."
|
20 |
+
assert in_channels % groups == 0, f"Number of input channels {in_channels} should be divisible by number of groups {groups}."
|
21 |
+
assert out_channels % groups == 0, f"Number of output channels {out_channels} should be divisible by number of groups {groups}."
|
22 |
+
self.grouped_conv = groups > 1
|
23 |
+
|
24 |
+
# conv1 is used for downsampling
|
25 |
+
# self.conv1 = nn.Conv2d(
|
26 |
+
# in_channels=in_channels,
|
27 |
+
# out_channels=in_channels,
|
28 |
+
# kernel_size=2,
|
29 |
+
# stride=2,
|
30 |
+
# padding=0,
|
31 |
+
# bias=not norm_layer,
|
32 |
+
# groups=groups,
|
33 |
+
# )
|
34 |
+
# if self.grouped_conv:
|
35 |
+
# self.conv1_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=not norm_layer)
|
36 |
+
self.conv1 = nn.AvgPool2d(kernel_size=2, stride=2) # downsample by 2
|
37 |
+
if self.grouped_conv:
|
38 |
+
self.conv1_1x1 = nn.Identity()
|
39 |
+
|
40 |
+
self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
|
41 |
+
self.act1 = activation
|
42 |
+
|
43 |
+
self.conv2 = conv3x3(
|
44 |
+
in_channels=in_channels,
|
45 |
+
out_channels=in_channels,
|
46 |
+
stride=1,
|
47 |
+
groups=groups,
|
48 |
+
bias=not norm_layer,
|
49 |
+
)
|
50 |
+
if self.grouped_conv:
|
51 |
+
self.conv2_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=not norm_layer)
|
52 |
+
|
53 |
+
self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
|
54 |
+
self.act2 = activation
|
55 |
+
|
56 |
+
self.conv3 = conv3x3(
|
57 |
+
in_channels=in_channels,
|
58 |
+
out_channels=out_channels,
|
59 |
+
stride=1,
|
60 |
+
groups=groups,
|
61 |
+
bias=not norm_layer,
|
62 |
+
)
|
63 |
+
if self.grouped_conv:
|
64 |
+
self.conv3_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=not norm_layer)
|
65 |
+
|
66 |
+
self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
67 |
+
self.act3 = activation
|
68 |
+
|
69 |
+
self.downsample = nn.Sequential(
|
70 |
+
nn.AvgPool2d(kernel_size=2, stride=2), # make sure the spatial sizes match
|
71 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
72 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
73 |
+
)
|
74 |
+
|
75 |
+
self.apply(_init_weights)
|
76 |
+
|
77 |
+
def forward(self, x: Tensor) -> Tensor:
|
78 |
+
identity = x
|
79 |
+
|
80 |
+
# downsample
|
81 |
+
out = self.conv1(x)
|
82 |
+
out = self.conv1_1x1(out) if self.grouped_conv else out
|
83 |
+
out = self.norm1(out)
|
84 |
+
out = self.act1(out)
|
85 |
+
|
86 |
+
out = self.conv2(out)
|
87 |
+
out = self.conv2_1x1(out) if self.grouped_conv else out
|
88 |
+
out = self.norm2(out)
|
89 |
+
out = self.act2(out)
|
90 |
+
|
91 |
+
out = self.conv3(out)
|
92 |
+
out = self.conv3_1x1(out) if self.grouped_conv else out
|
93 |
+
out = self.norm3(out)
|
94 |
+
|
95 |
+
# shortcut
|
96 |
+
out += self.downsample(identity)
|
97 |
+
out = self.act3(out)
|
98 |
+
return out
|
99 |
+
|
100 |
+
|
101 |
+
class LightConvDownsample(nn.Module):
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
in_channels: int,
|
105 |
+
out_channels: int,
|
106 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
107 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
108 |
+
) -> None:
|
109 |
+
super().__init__()
|
110 |
+
self.conv1 = DepthSeparableConv2d(
|
111 |
+
in_channels=in_channels,
|
112 |
+
out_channels=in_channels,
|
113 |
+
kernel_size=2,
|
114 |
+
stride=2,
|
115 |
+
padding=0,
|
116 |
+
bias=not norm_layer,
|
117 |
+
)
|
118 |
+
self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
|
119 |
+
self.act1 = activation
|
120 |
+
|
121 |
+
self.conv2 = DepthSeparableConv2d(
|
122 |
+
in_channels=in_channels,
|
123 |
+
out_channels=out_channels,
|
124 |
+
kernel_size=3,
|
125 |
+
stride=1,
|
126 |
+
padding=1,
|
127 |
+
bias=not norm_layer,
|
128 |
+
)
|
129 |
+
self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
130 |
+
self.act2 = activation
|
131 |
+
|
132 |
+
self.conv3 = DepthSeparableConv2d(
|
133 |
+
in_channels=out_channels,
|
134 |
+
out_channels=out_channels,
|
135 |
+
kernel_size=3,
|
136 |
+
stride=1,
|
137 |
+
padding=1,
|
138 |
+
bias=not norm_layer,
|
139 |
+
)
|
140 |
+
self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
141 |
+
self.act3 = activation
|
142 |
+
|
143 |
+
self.downsample = nn.Sequential(
|
144 |
+
nn.AvgPool2d(kernel_size=2, stride=2), # make sure the spatial sizes match
|
145 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
146 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
147 |
+
)
|
148 |
+
|
149 |
+
self.apply(_init_weights)
|
150 |
+
|
151 |
+
def forward(self, x: Tensor) -> Tensor:
|
152 |
+
identity = x
|
153 |
+
|
154 |
+
# downsample
|
155 |
+
out = self.conv1(x)
|
156 |
+
out = self.norm1(out)
|
157 |
+
out = self.act1(out)
|
158 |
+
|
159 |
+
# refine 1
|
160 |
+
out = self.conv2(out)
|
161 |
+
out = self.norm2(out)
|
162 |
+
out = self.act2(out)
|
163 |
+
|
164 |
+
# refine 2
|
165 |
+
out = self.conv3(out)
|
166 |
+
out = self.norm3(out)
|
167 |
+
|
168 |
+
# shortcut
|
169 |
+
out += self.downsample(identity)
|
170 |
+
out = self.act3(out)
|
171 |
+
return out
|
172 |
+
|
173 |
+
|
174 |
+
class LighterConvDownsample(nn.Module):
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
in_channels: int,
|
178 |
+
out_channels: int,
|
179 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
180 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
181 |
+
) -> None:
|
182 |
+
super().__init__()
|
183 |
+
self.conv1 = DepthSeparableConv2d(
|
184 |
+
in_channels=in_channels,
|
185 |
+
out_channels=in_channels,
|
186 |
+
kernel_size=2,
|
187 |
+
stride=2,
|
188 |
+
padding=0,
|
189 |
+
bias=not norm_layer,
|
190 |
+
)
|
191 |
+
self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
|
192 |
+
self.act1 = activation
|
193 |
+
|
194 |
+
self.conv2 = conv3x3(
|
195 |
+
in_channels=in_channels,
|
196 |
+
out_channels=in_channels,
|
197 |
+
stride=1,
|
198 |
+
groups=in_channels,
|
199 |
+
bias=not norm_layer,
|
200 |
+
)
|
201 |
+
self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
|
202 |
+
self.act2 = activation
|
203 |
+
|
204 |
+
self.conv3 = conv1x1(
|
205 |
+
in_channels=in_channels,
|
206 |
+
out_channels=out_channels,
|
207 |
+
stride=1,
|
208 |
+
bias=not norm_layer,
|
209 |
+
)
|
210 |
+
self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
211 |
+
self.act3 = activation
|
212 |
+
|
213 |
+
self.downsample = nn.Sequential(
|
214 |
+
nn.AvgPool2d(kernel_size=2, stride=2), # make sure the spatial sizes match
|
215 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
216 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
217 |
+
)
|
218 |
+
|
219 |
+
def forward(self, x: Tensor) -> Tensor:
|
220 |
+
identity = x
|
221 |
+
|
222 |
+
# downsample
|
223 |
+
out = self.conv1(x)
|
224 |
+
out = self.norm1(out)
|
225 |
+
out = self.act1(out)
|
226 |
+
|
227 |
+
# refine, depthwise conv
|
228 |
+
out = self.conv2(out)
|
229 |
+
out = self.norm2(out)
|
230 |
+
out = self.act2(out)
|
231 |
+
|
232 |
+
# refine, pointwise conv
|
233 |
+
out = self.conv3(out)
|
234 |
+
out = self.norm3(out)
|
235 |
+
|
236 |
+
# shortcut
|
237 |
+
out += self.downsample(identity)
|
238 |
+
out = self.act3(out)
|
239 |
+
return out
|
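All of the downsample blocks in this file halve the spatial resolution while mapping in_channels to out_channels; a quick shape check under the default BatchNorm/ReLU settings:

import torch

x = torch.randn(2, 64, 32, 32)
down = ConvDownsample(in_channels=64, out_channels=128)
light = LightConvDownsample(in_channels=64, out_channels=128)
print(down(x).shape, light(x).shape)  # both torch.Size([2, 128, 16, 16])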
models/utils/multi_scale.py
ADDED
@@ -0,0 +1,112 @@
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from typing import List
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
from .blocks import conv3x3, conv1x1, Conv2dLayerNorm, _init_weights
|
7 |
+
|
8 |
+
|
9 |
+
class MultiScale(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
channels: int,
|
13 |
+
scales: List[int],
|
14 |
+
heads: int = 8,
|
15 |
+
groups: int = 1,
|
16 |
+
mlp_ratio: float = 4.0,
|
17 |
+
) -> None:
|
18 |
+
super().__init__()
|
19 |
+
assert channels > 0, "channels should be a positive integer"
|
20 |
+
assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all([scale > 0 for scale in scales]), "scales should be a list or tuple of positive integers"
|
21 |
+
assert heads > 0 and channels % heads == 0, "heads should be a positive integer and channels should be divisible by heads"
|
22 |
+
assert groups > 0 and channels % groups == 0, "groups should be a positive integer and channels should be divisible by groups"
|
23 |
+
scales = sorted(scales)
|
24 |
+
self.scales = scales
|
25 |
+
self.num_scales = len(scales) + 1 # +1 for the original feature map
|
26 |
+
self.heads = heads
|
27 |
+
self.groups = groups
|
28 |
+
|
29 |
+
# modules that generate multi-scale feature maps
|
30 |
+
self.scale_0 = nn.Sequential(
|
31 |
+
conv1x1(channels, channels, stride=1, bias=False),
|
32 |
+
Conv2dLayerNorm(channels),
|
33 |
+
nn.GELU(),
|
34 |
+
)
|
35 |
+
for scale in scales:
|
36 |
+
setattr(self, f"conv_{scale}", nn.Sequential(
|
37 |
+
conv3x3(
|
38 |
+
in_channels=channels,
|
39 |
+
out_channels=channels,
|
40 |
+
stride=1,
|
41 |
+
groups=groups,
|
42 |
+
dilation=scale,
|
43 |
+
bias=False,
|
44 |
+
),
|
45 |
+
conv1x1(channels, channels, stride=1, bias=False) if groups > 1 else nn.Identity(),
|
46 |
+
Conv2dLayerNorm(channels),
|
47 |
+
nn.GELU(),
|
48 |
+
))
|
49 |
+
|
50 |
+
# modules that fuse multi-scale feature maps
|
51 |
+
self.norm_attn = Conv2dLayerNorm(channels)
|
52 |
+
self.pos_embed = nn.Parameter(torch.randn(1, self.num_scales + 1, channels, 1, 1) / channels ** 0.5)
|
53 |
+
self.to_q = conv1x1(channels, channels, stride=1, bias=False)
|
54 |
+
self.to_k = conv1x1(channels, channels, stride=1, bias=False)
|
55 |
+
self.to_v = conv1x1(channels, channels, stride=1, bias=False)
|
56 |
+
|
57 |
+
self.scale = (channels // heads) ** -0.5
|
58 |
+
|
59 |
+
self.attend = nn.Softmax(dim=-1)
|
60 |
+
|
61 |
+
self.to_out = conv1x1(channels, channels, stride=1)
|
62 |
+
|
63 |
+
# modules that refine multi-scale feature maps
|
64 |
+
self.norm_mlp = Conv2dLayerNorm(channels)
|
65 |
+
self.mlp = nn.Sequential(
|
66 |
+
conv1x1(channels, int(channels * mlp_ratio), stride=1),
|
67 |
+
nn.GELU(),
|
68 |
+
conv1x1(int(channels * mlp_ratio), channels, stride=1),
|
69 |
+
)
|
70 |
+
|
71 |
+
self.apply(_init_weights)
|
72 |
+
|
73 |
+
def _forward_attn(self, x: Tensor) -> Tensor:
|
74 |
+
assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
|
75 |
+
x = [self.scale_0(x)] + [getattr(self, f"conv_{scale}")(x) for scale in self.scales]
|
76 |
+
|
77 |
+
x = torch.stack(x, dim=1) # (B, S, C, H, W)
|
78 |
+
x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1) # (B, S+1, C, H, W)
|
79 |
+
x = x + self.pos_embed # (B, S+1, C, H, W)
|
80 |
+
|
81 |
+
x = rearrange(x, "B S C H W -> (B S) C H W") # (B*(S+1), C, H, W)
|
82 |
+
x = self.norm_attn(x) # (B*(S+1), C, H, W)
|
83 |
+
x = rearrange(x, "(B S) C H W -> B S C H W", S=self.num_scales + 1) # (B, S+1, C, H, W)
|
84 |
+
|
85 |
+
q = self.to_q(x[:, 0]) # (B, C, H, W)
|
86 |
+
k = self.to_k(rearrange(x, "B S C H W -> (B S) C H W"))
|
87 |
+
v = self.to_v(rearrange(x, "B S C H W -> (B S) C H W"))
|
88 |
+
|
89 |
+
q = rearrange(q, "B (h d) H W -> B h H W 1 d", h=self.heads)
|
90 |
+
k = rearrange(k, "(B S) (h d) H W -> B h H W S d", S=self.num_scales + 1, h=self.heads)
|
91 |
+
v = rearrange(v, "(B S) (h d) H W -> B h H W S d", S=self.num_scales + 1, h=self.heads)
|
92 |
+
|
93 |
+
attn = q @ k.transpose(-2, -1) * self.scale # (B, h, H, W, 1, S+1)
|
94 |
+
attn = self.attend(attn) # (B, h, H, W, 1, S+1)
|
95 |
+
out = attn @ v # (B, h, H, W, 1, d)
|
96 |
+
|
97 |
+
out = rearrange(out, "B h H W 1 d -> B (h d) H W") # (B, C, H, W)
|
98 |
+
|
99 |
+
out = self.to_out(out) # (B, C, H, W)
|
100 |
+
return out
|
101 |
+
|
102 |
+
def _forward_mlp(self, x: Tensor) -> Tensor:
|
103 |
+
assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
|
104 |
+
x = self.norm_mlp(x)
|
105 |
+
x = self.mlp(x)
|
106 |
+
return x
|
107 |
+
|
108 |
+
def forward(self, x: Tensor) -> Tensor:
|
109 |
+
x = x + self._forward_attn(x)
|
110 |
+
x = x + self._forward_mlp(x)
|
111 |
+
return x
|
112 |
+
|
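MultiScale preserves both the channel count and the spatial size: it fuses the original feature map with its dilated views through per-pixel attention over scales. A quick shape check (channels must be divisible by both heads and groups):

import torch

ms = MultiScale(channels=64, scales=[2, 3], heads=8, groups=1, mlp_ratio=4.0)
x = torch.randn(2, 64, 32, 32)
print(ms(x).shape)  # torch.Size([2, 64, 32, 32])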
models/utils/refine.py
ADDED
@@ -0,0 +1,103 @@
1 |
+
from torch import nn, Tensor
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
from .utils import _init_weights
|
5 |
+
from .blocks import BasicBlock, LightBasicBlock, conv1x1, conv3x3
|
6 |
+
|
7 |
+
|
8 |
+
class ConvRefine(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
in_channels: int,
|
12 |
+
out_channels: int,
|
13 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
14 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
15 |
+
groups: int = 1,
|
16 |
+
) -> None:
|
17 |
+
super().__init__()
|
18 |
+
self.refine = BasicBlock(
|
19 |
+
in_channels=in_channels,
|
20 |
+
out_channels=out_channels,
|
21 |
+
norm_layer=norm_layer,
|
22 |
+
activation=activation,
|
23 |
+
groups=groups,
|
24 |
+
)
|
25 |
+
self.apply(_init_weights)
|
26 |
+
|
27 |
+
def forward(self, x: Tensor) -> Tensor:
|
28 |
+
return self.refine(x)
|
29 |
+
|
30 |
+
|
31 |
+
class LightConvRefine(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
in_channels: int,
|
35 |
+
out_channels: int,
|
36 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
37 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
self.refine = LightBasicBlock(
|
41 |
+
in_channels=in_channels,
|
42 |
+
out_channels=out_channels,
|
43 |
+
norm_layer=norm_layer,
|
44 |
+
activation=activation,
|
45 |
+
)
|
46 |
+
self.apply(_init_weights)
|
47 |
+
|
48 |
+
def forward(self, x: Tensor) -> Tensor:
|
49 |
+
return self.refine(x)
|
50 |
+
|
51 |
+
|
52 |
+
class LighterConvRefine(nn.Module):
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
in_channels: int,
|
56 |
+
out_channels: int,
|
57 |
+
norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
|
58 |
+
activation: nn.Module = nn.ReLU(inplace=True),
|
59 |
+
) -> None:
|
60 |
+
super().__init__()
|
61 |
+
# depthwise separable convolution
|
62 |
+
self.conv1 = conv3x3(
|
63 |
+
in_channels=in_channels,
|
64 |
+
out_channels=in_channels,
|
65 |
+
stride=1,
|
66 |
+
groups=in_channels,
|
67 |
+
bias=not norm_layer,
|
68 |
+
)
|
69 |
+
self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
|
70 |
+
self.act1 = activation
|
71 |
+
|
72 |
+
self.conv2 = conv1x1(
|
73 |
+
in_channels=in_channels,
|
74 |
+
out_channels=out_channels,
|
75 |
+
stride=1,
|
76 |
+
bias=not norm_layer,
|
77 |
+
)
|
78 |
+
self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
|
79 |
+
self.act2 = activation
|
80 |
+
|
81 |
+
if in_channels != out_channels:
|
82 |
+
self.downsample = nn.Sequential(
|
83 |
+
conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
|
84 |
+
norm_layer(out_channels) if norm_layer else nn.Identity(),
|
85 |
+
)
|
86 |
+
else:
|
87 |
+
self.downsample = nn.Identity()
|
88 |
+
|
89 |
+
self.apply(_init_weights)
|
90 |
+
|
91 |
+
def forward(self, x: Tensor) -> Tensor:
|
92 |
+
identity = x
|
93 |
+
|
94 |
+
out = self.conv1(x)
|
95 |
+
out = self.norm1(out)
|
96 |
+
out = self.act1(out)
|
97 |
+
|
98 |
+
out = self.conv2(out)
|
99 |
+
out = self.norm2(out)
|
100 |
+
|
101 |
+
out += self.downsample(identity)
|
102 |
+
out = self.act2(out)
|
103 |
+
return out
|
models/utils/upsample.py
ADDED
@@ -0,0 +1,118 @@
from torch import nn, Tensor
from torch.nn import functional as F

from typing import Union
from functools import partial

from .utils import _init_weights
from .refine import ConvRefine, LightConvRefine, LighterConvRefine


class ConvUpsample(nn.Module):
    """Bilinear upsampling by `scale_factor`, followed by a ConvRefine block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor
        # With scale_factor == 1 the interpolation step is skipped entirely.
        self.upsample = partial(
            F.interpolate,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
            antialias=False,
        ) if scale_factor > 1 else nn.Identity()

        self.refine = ConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
            groups=groups,
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample(x)
        x = self.refine(x)
        return x


class LightConvUpsample(nn.Module):
    """Same as ConvUpsample, but refines with the lighter LightConvRefine block (no groups argument)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor
        self.upsample = partial(
            F.interpolate,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
            antialias=False,
        ) if scale_factor > 1 else nn.Identity()

        self.refine = LightConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample(x)
        x = self.refine(x)
        return x


class LighterConvUpsample(nn.Module):
    """Same as ConvUpsample, but refines with the lightest LighterConvRefine block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor
        self.upsample = partial(
            F.interpolate,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
            antialias=False,
        ) if scale_factor > 1 else nn.Identity()

        self.refine = LighterConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample(x)
        x = self.refine(x)
        return x
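A minimal usage sketch for the upsample modules defined above (values are illustrative; the expected output shapes assume the refine blocks preserve the spatial resolution produced by the bilinear step):

import torch
from models.utils.upsample import ConvUpsample, LighterConvUpsample  # assumed import path

up = ConvUpsample(in_channels=256, out_channels=128, scale_factor=2)
feat = torch.randn(1, 256, 16, 16)
out = up(feat)                    # bilinear 2x upsample, then ConvRefine: expected (1, 128, 32, 32)

# scale_factor=1 swaps the interpolation for nn.Identity, so only the refine block runs.
same_res = LighterConvUpsample(in_channels=256, out_channels=128, scale_factor=1)
out2 = same_res(feat)             # expected (1, 128, 16, 16)

Note that self.upsample is a functools.partial around F.interpolate rather than an nn.Module, so it carries no parameters; only the refine block contributes weights, which are initialised by _init_weights.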
models/utils/utils.py
ADDED
@@ -0,0 +1,77 @@
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Tuple, Any, Optional, Union
from types import FunctionType
from itertools import repeat
from collections.abc import Iterable


def _log_api_usage_once(obj: Any) -> None:
    """
    Logs API usage (module and name) within an organization.
    In a large ecosystem, it's often useful to track PyTorch and TorchVision
    API usage. This API provides functionality similar to the logging module
    in the Python stdlib. It can be used for debugging purposes to log which
    methods are used, and by default it is inactive unless the user manually
    subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is a no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    module = obj.__module__
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{module}.{name}")


def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make an n-tuple from input x. If x is an iterable, it is simply converted to a tuple.
    Otherwise, a tuple of length n is built with every element equal to x.
    Reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, Iterable):
        return tuple(x)
    return tuple(repeat(x, n))


def _init_weights(model: nn.Module) -> None:
    # Kaiming init for convolutions, unit weight / zero bias for norm layers,
    # small normal init for linear layers.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            nn.init.constant_(m.weight, 1.)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)


def interpolate_pos_embed(
    pos_embed: Tensor,
    size: Optional[Union[int, Tuple[int, int]]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    # Bicubic resizing of a (C, H, W) positional-embedding grid.
    assert len(pos_embed.shape) == 3, f"Positional embedding should be 3D tensor (C, H, W), but got {pos_embed.shape}."
    return F.interpolate(
        pos_embed.unsqueeze(0),
        size=size,
        scale_factor=scale_factor,
        mode="bicubic",
        align_corners=False,
        antialias=True,
    ).squeeze(0)
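A short illustrative sketch of the helpers above: _make_ntuple broadcasts scalars to fixed-length tuples, _init_weights applies the Kaiming/constant/normal scheme recursively, and interpolate_pos_embed bicubically resizes a (C, H, W) positional-embedding grid, e.g. when a ViT backbone is run at a different input resolution. All values below are hypothetical.

import torch
from torch import nn
from models.utils.utils import _make_ntuple, _init_weights, interpolate_pos_embed  # assumed import path

_make_ntuple(3, 2)               # -> (3, 3): scalars are repeated n times
_make_ntuple((3, 5), 2)          # -> (3, 5): iterables are passed through as a tuple

# Resize a 768-channel 14x14 positional grid to 24x24 (e.g. 224 -> 384 input with 16x16 patches).
pos = torch.randn(768, 14, 14)
pos_resized = interpolate_pos_embed(pos, size=(24, 24))   # -> shape (768, 24, 24)

# Apply the initialisation scheme to a small prediction head.
head = nn.Sequential(nn.Conv2d(768, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
_init_weights(head)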