Upload 9 files

- app.py +243 -0
- ckpts/dino_spair_0300.pth +3 -0
- model_utils/__pycache__/extractor_dino.cpython-310.pyc +0 -0
- model_utils/__pycache__/projection_network.cpython-310.pyc +0 -0
- model_utils/__pycache__/resnet.cpython-310.pyc +0 -0
- model_utils/extractor_dino.py +356 -0
- model_utils/projection_network.py +167 -0
- model_utils/resnet.py +518 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,243 @@
import os
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from matplotlib import cm

from model_utils.extractor_dino import ViTExtractor
from model_utils.projection_network import AggregationNetwork, DummyAggregationNetwork

def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_filter='lanczos'):
    filt = Image.Resampling.LANCZOS if sampling_filter == 'lanczos' else Image.Resampling.NEAREST
    original_width, original_height = img.size
    original_channels = len(img.getbands())
    if not edge:
        canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8)
        if original_channels == 1:
            canvas = np.zeros([target_res, target_res], dtype=np.uint8)
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[(width - height) // 2: (width + height) // 2] = img
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[:, (height - width) // 2: (height + width) // 2] = img
    else:
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            top_pad = (target_res - height) // 2
            bottom_pad = target_res - height - top_pad
            img = np.pad(img, pad_width=[(top_pad, bottom_pad), (0, 0), (0, 0)], mode='edge')
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            left_pad = (target_res - width) // 2
            right_pad = target_res - width - left_pad
            img = np.pad(img, pad_width=[(0, 0), (left_pad, right_pad), (0, 0)], mode='edge')
        canvas = img
    if to_pil:
        canvas = Image.fromarray(canvas)
    return canvas

# ─── Configuration ───────────────────────────────────────────────
num_patches = 30
target_res = num_patches * 14
ckpt_file = "ckpts/dino_spair_0300.pth"

# ─── Model setup ─────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'
aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
aggre_net_dummy = DummyAggregationNetwork()
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)

# ─── Feature extraction ──────────────────────────────────────────
def get_processed_features_dino(num_patches, img, use_dummy):
    batch = extractor_vit.preprocess_pil(img)
    features_dino = extractor_vit.extract_descriptors(batch.cuda(), layer=11, facet='token') \
        .permute(0, 1, 3, 2) \
        .reshape(1, -1, num_patches, num_patches)
    # Project + normalize
    with torch.no_grad():
        if use_dummy == "DINOv2":
            desc = aggre_net_dummy(features_dino)
        else:
            desc = aggre_net(features_dino)
        norms = torch.linalg.norm(desc, dim=1, keepdim=True)
        desc = desc / (norms + 1e-8)
    return desc  # shape [1, C, num_patches, num_patches]

# ─── Similarity computation ───────────────────────────────────────
def get_sim(
    coord: tuple[int, int],
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    img_size: int = target_res
) -> np.ndarray:
    """
    Upsamples the DINO features to `img_size`, then computes cosine‐similarity
    between the feature at `coord` in source and every spatial location in target.
    """
    y, x = coord  # row, col

    # Upsample both feature maps to [1, C, img_size, img_size]
    upsampler = nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=False)
    src_ft = upsampler(feat1)  # [1, C, img_size, img_size]
    trg_ft = upsampler(feat2)

    # Extract the C‐dim vector at the clicked location
    C = src_ft.size(1)
    src_vec = src_ft[0, :, y, x].view(1, C, 1, 1)  # [1, C, 1, 1]

    # Cosine similarity along channel‐dim
    cos = nn.CosineSimilarity(dim=1)
    cos_map = cos(src_vec, trg_ft)[0]  # [img_size, img_size]
    return cos_map.cpu().numpy()

# ─── Drawing helper ───────────────────────────────────────────────
def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255, 0, 0)) -> np.ndarray:
    pil = Image.fromarray(img_arr)
    draw = ImageDraw.Draw(pil)
    r = size // 2
    draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
    return np.array(pil)

# ─── Feature‐updating callback ───────────────────────────────────
def update_features(
    img: Image,
    num_patches,
    use_dummy
):
    """
    Given a PIL image, returns:
      1) the same PIL image (so it can be displayed)
      2) its DINO descriptor tensor, stored in a gr.State
    """
    if img is None:
        return None, None, None
    img = resize(img, target_res=target_res, resize=True, to_pil=True)
    feat = get_processed_features_dino(num_patches, img=img, use_dummy=use_dummy)
    return img, feat.cpu(), Image.fromarray(np.array(img))

# ─── Click handler ───────────────────────────────────────────────
def on_select(
    source_pil: Image,
    target_pil: Image,
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    alpha: float,
    scatter_size: int,
    or_tgt_img: Image,
    or_src_img: Image,
    sel: gr.SelectData
):
    # Convert to numpy arrays
    src_arr = np.array(or_src_img)
    tgt_arr = np.array(or_tgt_img)

    # Get click coords (row, col)
    y, x = sel.index

    src_marked = draw_point(src_arr, y, x, scatter_size)

    # Compute similarity map
    sim_map = get_sim((x, y), feat1, feat2, img_size=target_res)

    mn, mx = sim_map.min(), sim_map.max()
    sim_norm = (sim_map - mn) / ((mx - mn) + 1e-12)

    # Build RGBA heatmap
    heat = cm.viridis(sim_norm)       # H×W×4
    heat[..., 3] = sim_norm * alpha   # alpha channel

    # Composite over fresh target
    tgt_f = tgt_arr.astype(np.float32) / 255.0
    comp = heat[..., :3] * heat[..., 3:4] + tgt_f * (1 - heat[..., 3:4])
    overlay = (comp * 255).astype(np.uint8)

    # Draw a red dot at the best match
    my, mx_ = np.unravel_index(sim_map.argmax(), sim_map.shape)
    overlay_marked = draw_point(overlay, mx_, my, scatter_size)

    return src_marked, overlay_marked

def reload_img(
    or_src_img: Image,
    or_tgt_img: Image,
):
    return or_src_img, or_tgt_img


# ─── Build Gradio UI ──────────────────────────────────────────────
with gr.Blocks() as demo:
    # Hidden states to hold features
    feat1_state = gr.State()
    feat2_state = gr.State()
    or_tgt_img = gr.State()
    or_src_img = gr.State()

    # Introduction text box
    intro_text = gr.Markdown("""
    ## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
    [Project Page](https://example.com) | [GitHub Repository](https://github.com/example/repo)

    Welcome to the DIY-SC demo!
    Upload two images and select a keypoint in the source image. This demo will compute and visualize the feature similarity map and a corresponding point in the target image.
    You can choose between the DIY-SC (DINOv2) and a DINOv2 feature extractor.
    """)

    # Image upload / display components
    with gr.Row():
        src = gr.Image(interactive=True, type="pil", label="Source Image")
        tgt = gr.Image(interactive=True, type="pil", label="Target Image")

    # Controls
    alpha = gr.State(0.7)
    scatter = gr.State(10)
    use_dummy = gr.Radio(["DIY-SC", "DINOv2"], value="DIY-SC", label="Feature Extractor")

    src.input(
        fn=update_features,
        inputs=[src, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )

    tgt.input(
        fn=update_features,
        inputs=[tgt, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )

    use_dummy.change(
        fn=update_features,
        inputs=[or_src_img, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )

    use_dummy.change(
        fn=update_features,
        inputs=[or_tgt_img, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )

    src.select(
        fn=on_select,
        inputs=[src, tgt, feat1_state, feat2_state, alpha, scatter, or_tgt_img, or_src_img],
        outputs=[src, tgt]
    )

demo.launch(share=True)
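For reference, a minimal headless sketch (not part of the commit) of how the pieces in app.py fit together without the Gradio UI. It assumes the functions above are available in the current Python session, uses stand-in image paths, and needs a CUDA device because get_processed_features_dino calls batch.cuda():

# Sketch only: stand-in paths, helpers from app.py assumed to be in scope.
from PIL import Image
import numpy as np

src_img = resize(Image.open("source.jpg").convert("RGB"), target_res=target_res, resize=True, to_pil=True)
tgt_img = resize(Image.open("target.jpg").convert("RGB"), target_res=target_res, resize=True, to_pil=True)

feat_src = get_processed_features_dino(num_patches, img=src_img, use_dummy="DIY-SC")
feat_tgt = get_processed_features_dino(num_patches, img=tgt_img, use_dummy="DIY-SC")

# Query point (row=200, col=150) in the resized source image
sim_map = get_sim((200, 150), feat_src, feat_tgt, img_size=target_res)
row, col = np.unravel_index(sim_map.argmax(), sim_map.shape)
print(f"best match in target: row={row}, col={col}, cosine={sim_map[row, col]:.3f}")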
ckpts/dino_spair_0300.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42e9c8a4d27041af7bb8dc5240ea3e83d47e013b72541dc8c5829368f247e705
size 2519767
model_utils/__pycache__/extractor_dino.cpython-310.pyc
ADDED
Binary file (14.5 kB)
model_utils/__pycache__/projection_network.cpython-310.pyc
ADDED
Binary file (5.36 kB)
model_utils/__pycache__/resnet.cpython-310.pyc
ADDED
Binary file (15.4 kB)
model_utils/extractor_dino.py
ADDED
@@ -0,0 +1,356 @@
import torch
from torch import nn
from torchvision import transforms
import torch.nn.modules.utils as nn_utils
import math
import types
from pathlib import Path
from typing import Union, List, Tuple
from PIL import Image


class ViTExtractor:
    """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.
    We use the following notation in the documentation of the module's methods:
    B - batch size
    h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW
    p - patch size of the ViT. either 8 or 16.
    t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width
    of the input image.
    d - the embedding dimension in the ViT.
    """

    def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'):
        """
        :param model_type: A string specifying the type of model to extract from.
                          [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
                          vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]
        :param stride: stride of first convolution layer. small stride -> higher resolution.
        :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
                      should be compatible with model_type.
        """
        self.model_type = model_type
        self.device = device
        if model is not None:
            self.model = model
        else:
            self.model = ViTExtractor.create_model(model_type)

        self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
        self.model.eval()
        self.model.to(self.device)
        self.p = self.model.patch_embed.patch_size
        if type(self.p) == tuple:
            self.p = self.p[0]
        self.stride = self.model.patch_embed.proj.stride

        self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
        self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5)

        self._feats = []
        self.hook_handlers = []
        self.load_size = None
        self.num_patches = None

    @staticmethod
    def create_model(model_type: str) -> nn.Module:
        """
        :param model_type: a string specifying which model to load. [dino_vits8 | dino_vits16 | dino_vitb8 |
                           dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 |
                           vit_base_patch16_224]
        :return: the model
        """
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        if 'v2' in model_type:
            model = torch.hub.load('facebookresearch/dinov2', model_type)
        elif 'dino' in model_type:
            model = torch.hub.load('facebookresearch/dino:main', model_type)
        elif 'ibot' in model_type:
            model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
            temp_state_dict = torch.load("ibot/checkpoint_teacher.pth", map_location="cpu")
            temp_state_dict = temp_state_dict["state_dict"]
            # remove `module.` prefix
            temp_state_dict = {k.replace("module.", ""): v for k, v in temp_state_dict.items()}
            # remove `backbone.` prefix induced by multicrop wrapper
            temp_state_dict = {k.replace("backbone.", ""): v for k, v in temp_state_dict.items()}
            msg = model.load_state_dict(temp_state_dict, strict=False)
            print(msg)
        else:  # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
            import timm
            temp_model = timm.create_model(model_type, pretrained=True)
            model_type_dict = {
                'vit_small_patch16_224': 'dino_vits16',
                'vit_small_patch8_224': 'dino_vits8',
                'vit_base_patch16_224': 'dino_vitb16',
                'vit_base_patch8_224': 'dino_vitb8'
            }
            model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type])
            temp_state_dict = temp_model.state_dict()
            del temp_state_dict['head.weight']
            del temp_state_dict['head.bias']
            model.load_state_dict(temp_state_dict)
        return model

    @staticmethod
    def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]):
        """
        Creates a method for position encoding interpolation.
        :param patch_size: patch size of the model.
        :param stride_hw: A tuple containing the new height and width stride respectively.
        :return: the interpolation method
        """
        def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
            npatch = x.shape[1] - 1
            N = self.pos_embed.shape[1] - 1
            if npatch == N and w == h:
                return self.pos_embed
            class_pos_embed = self.pos_embed[:, 0]
            patch_pos_embed = self.pos_embed[:, 1:]
            dim = x.shape[-1]
            # compute number of tokens taking stride into account
            w0 = 1 + (w - patch_size) // stride_hw[1]
            h0 = 1 + (h - patch_size) // stride_hw[0]
            assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and
                                            stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}"""
            # we add a small number to avoid floating point error in the interpolation
            # see discussion at https://github.com/facebookresearch/dino/issues/8
            w0, h0 = w0 + 0.1, h0 + 0.1
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
                mode='bicubic',
                align_corners=False, recompute_scale_factor=False
            )
            assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

        return interpolate_pos_encoding

    @staticmethod
    def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
        """
        change resolution of model output by changing the stride of the patch extraction.
        :param model: the model to change resolution for.
        :param stride: the new stride parameter.
        :return: the adjusted model
        """
        patch_size = model.patch_embed.patch_size
        if type(patch_size) == tuple:
            patch_size = patch_size[0]
        if stride == patch_size:  # nothing to do
            return model

        stride = nn_utils._pair(stride)
        assert all([(patch_size // s_) * s_ == patch_size for s_ in
                    stride]), f'stride {stride} should divide patch_size {patch_size}'

        # fix the stride
        model.patch_embed.proj.stride = stride
        # fix the positional encoding code
        model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model)
        return model

    def preprocess(self, image_path: Union[str, Path],
                   load_size: Union[int, Tuple[int, int]] = None, patch_size: int = 14) -> Tuple[torch.Tensor, Image.Image]:
        """
        Preprocesses an image before extraction.
        :param image_path: path to image to be extracted.
        :param load_size: optional. Size to resize image before the rest of preprocessing.
        :return: a tuple containing:
                    (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW.
                    (2) the pil image in relevant dimensions
        """
        def divisible_by_num(num, dim):
            return num * (dim // num)
        pil_image = Image.open(image_path).convert('RGB')
        if load_size is not None:
            pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)

            width, height = pil_image.size
            new_width = divisible_by_num(patch_size, width)
            new_height = divisible_by_num(patch_size, height)
            pil_image = pil_image.resize((new_width, new_height), resample=Image.LANCZOS)

        prep = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ])
        prep_img = prep(pil_image)[None, ...]
        return prep_img, pil_image

    def preprocess_pil(self, pil_image):
        """
        Preprocesses an image before extraction.
        :param image_path: path to image to be extracted.
        :param load_size: optional. Size to resize image before the rest of preprocessing.
        :return: a tuple containing:
                    (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW.
                    (2) the pil image in relevant dimensions
        """
        prep = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ])
        prep_img = prep(pil_image)[None, ...]
        return prep_img

    def _get_hook(self, facet: str):
        """
        generate a hook method for a specific block and facet.
        """
        if facet in ['attn', 'token']:
            def _hook(model, input, output):
                self._feats.append(output)
            return _hook

        if facet == 'query':
            facet_idx = 0
        elif facet == 'key':
            facet_idx = 1
        elif facet == 'value':
            facet_idx = 2
        else:
            raise TypeError(f"{facet} is not a supported facet.")

        def _inner_hook(module, input, output):
            input = input[0]
            B, N, C = input.shape
            qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
            self._feats.append(qkv[facet_idx])  # Bxhxtxd
        return _inner_hook

    def _register_hooks(self, layers: List[int], facet: str) -> None:
        """
        register hook to extract features.
        :param layers: layers from which to extract features.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
        """
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in layers:
                if facet == 'token':
                    self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
                elif facet == 'attn':
                    self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
                elif facet in ['key', 'query', 'value']:
                    self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
                else:
                    raise TypeError(f"{facet} is not a supported facet.")

    def _unregister_hooks(self) -> None:
        """
        unregisters the hooks. should be called after feature extraction.
        """
        for handle in self.hook_handlers:
            handle.remove()
        self.hook_handlers = []

    def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]:
        """
        extract features from the model
        :param batch: batch to extract features for. Has shape BxCxHxW.
        :param layers: layer to extract. A number between 0 to 11.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
        :return : tensor of features.
                  if facet is 'key' | 'query' | 'value' has shape Bxhxtxd
                  if facet is 'attn' has shape Bxhxtxt
                  if facet is 'token' has shape Bxtxd
        """
        B, C, H, W = batch.shape
        self._feats = []
        self._register_hooks(layers, facet)
        _ = self.model(batch)
        self._unregister_hooks()
        self.load_size = (H, W)
        self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
        return self._feats

    def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
        """
        create a log-binned descriptor.
        :param x: tensor of features. Has shape Bxhxtxd. [1,6,3410,64]
        :param hierarchy: how many bin hierarchies to use.
        """
        B = x.shape[0]
        num_bins = 1 + 8 * hierarchy

        bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1)  # Bx(t-1)x(dxh) [1,3410,384]
        bin_x = bin_x.permute(0, 2, 1)
        bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1])
        # Bx(dxh)xnum_patches[0]xnum_patches[1]
        sub_desc_dim = bin_x.shape[1]

        avg_pools = []
        # compute bins of all sizes for all spatial locations.
        for k in range(0, hierarchy):
            # avg pooling with kernel 3**kx3**k
            win_size = 3 ** k
            avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
            avg_pools.append(avg_pool(bin_x))

        bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device)
        for y in range(self.num_patches[0]):
            for x in range(self.num_patches[1]):
                part_idx = 0
                # fill all bins for a spatial location (y, x)
                for k in range(0, hierarchy):
                    kernel_size = 3 ** k
                    for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
                        for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
                            if i == y and j == x and k != 0:
                                continue
                            if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]:
                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
                                    :, :, i, j]
                            else:  # handle padding in a more delicate way than zero padding
                                temp_i = max(0, min(i, self.num_patches[0] - 1))
                                temp_j = max(0, min(j, self.num_patches[1] - 1))
                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
                                    :, :, temp_i, temp_j]
                            part_idx += 1
        bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
        # Bx1x(t-1)x(dxh)
        return bin_x  # [1,1,3410,6528]

    def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key',
                            bin: bool = False, include_cls: bool = False) -> torch.Tensor:
        """
        extract descriptors from the model
        :param batch: batch to extract descriptors for. Has shape BxCxHxW.
        :param layers: layer to extract. A number between 0 to 11.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
        :param bin: apply log binning to the descriptor. default is False.
        :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
        """
        assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
                                                             choose from ['key' | 'query' | 'value' | 'token'] """
        self._extract_features(batch, [layer], facet)
        x = self._feats[0]
        if facet == 'token':
            x.unsqueeze_(dim=1)  # Bx1xtxd
        if not include_cls:
            x = x[:, :, 1:, :]  # remove cls token
        else:
            assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
        if not bin:
            desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
        else:
            desc = self._log_bin(x)
        return desc

    def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
        """
        extract saliency maps. The saliency maps are extracted by averaging several attention heads from the last layer
        in of the CLS token. All values are then normalized to range between 0 and 1.
        :param batch: batch to extract saliency maps for. Has shape BxCxHxW.
        :return: a tensor of saliency maps. has shape Bxt-1
        """
        assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type."
        self._extract_features(batch, [11], 'attn')
        head_idxs = [0, 2, 4, 5]
        curr_feats = self._feats[0]  # Bxhxtxt
        cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1)  # Bx(t-1)
        temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0]
        cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins)  # normalize to range [0,1]
        return cls_attn_maps
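A minimal usage sketch for ViTExtractor (not part of the commit), mirroring how app.py calls it: load DINOv2 ViT-B/14 with stride 14, preprocess a PIL image, and pull layer-11 token descriptors. The image path is a stand-in and a CUDA device is assumed:

import torch
from PIL import Image
from model_utils.extractor_dino import ViTExtractor

extractor = ViTExtractor('dinov2_vitb14', stride=14, device='cuda')
# 420 px = 30 patches * 14 px, matching the demo's target resolution
img = Image.open("example.jpg").convert("RGB").resize((420, 420))
batch = extractor.preprocess_pil(img)
with torch.no_grad():
    desc = extractor.extract_descriptors(batch.cuda(), layer=11, facet='token')
print(desc.shape)  # [1, 1, 900, 768]: 30x30 patch tokens, 768-dim each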
model_utils/projection_network.py
ADDED
@@ -0,0 +1,167 @@
import numpy as np
import torch
from torch import nn
from model_utils.resnet import ResNet, BottleneckBlock
import torch.nn.functional as F

class DummyAggregationNetwork(nn.Module):  # for testing, return the input
    def __init__(self):
        super(DummyAggregationNetwork, self).__init__()
        # dummy paprameter
        self.dummy = nn.Parameter(torch.ones([]))
    def forward(self, batch, pose=None):
        return batch * self.dummy

class AggregationNetwork(nn.Module):
    """
    Module for aggregating feature maps across time and space.
    Design inspired by the Feature Extractor from ODISE (Xu et. al., CVPR 2023).
    https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py
    """
    def __init__(
            self,
            device,
            feature_dims=[640, 1280, 1280, 768],
            projection_dim=384,
            num_norm_groups=32,
            save_timestep=[1],
            kernel_size=[1, 3, 1],
            contrastive_temp=10,
            feat_map_dropout=0.0,
            num_blocks=None,
            bottleneck_channels=None
    ):
        super().__init__()
        self.skip_connection = True
        self.feat_map_dropout = feat_map_dropout
        self.azimuth_embedding = None
        self.pos_embedding = None
        self.bottleneck_layers = nn.ModuleList()
        self.feature_dims = feature_dims
        self.num_blocks = num_blocks if num_blocks is not None else 1
        self.bottleneck_channels = bottleneck_channels if bottleneck_channels is not None else projection_dim // 4
        # For CLIP symmetric cross entropy loss during training
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
        self.device = device
        self.save_timestep = save_timestep

        self.mixing_weights_names = []
        for l, feature_dim in enumerate(self.feature_dims):
            bottleneck_layer = nn.Sequential(
                *ResNet.make_stage(
                    BottleneckBlock,
                    num_blocks=self.num_blocks,
                    in_channels=feature_dim,
                    bottleneck_channels=self.bottleneck_channels,
                    out_channels=projection_dim,
                    norm="GN",
                    num_norm_groups=num_norm_groups,
                    kernel_size=kernel_size
                )
            )
            self.bottleneck_layers.append(bottleneck_layer)
            for t in save_timestep:
                # 1-index the layer name following prior work
                self.mixing_weights_names.append(f"timestep-{save_timestep}_layer-{l+1}")
        self.last_layer = None
        self.bottleneck_layers = self.bottleneck_layers.to(device)
        mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
        self.mixing_weights = nn.Parameter(mixing_weights.to(device))
        # count number of parameters
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(f"AggregationNetwork has {num_params} parameters.")

    def load_pretrained_weights(self, pretrained_dict):
        custom_dict = self.state_dict()

        # Handle size mismatch
        if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict['mixing_weights'].shape != pretrained_dict['mixing_weights'].shape:
            # Keep the first four weights from the pretrained model, and randomly initialize the fifth weight
            custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
            custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4])
        else:
            custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]

        # Load the weights that do match
        matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'}
        custom_dict.update(matching_keys)

        # Now load the updated state_dict
        self.load_state_dict(custom_dict, strict=False)

    def forward(self, batch, pose=None):
        """
        Assumes batch is shape (B, C, H, W) where C is the concatentation of all layer features.
        """
        if self.feat_map_dropout > 0 and self.training:
            batch = F.dropout(batch, p=self.feat_map_dropout)

        output_feature = None
        start = 0
        mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0)
        if self.pos_embedding is not None:  # position embedding
            batch = torch.cat((batch, self.pos_embedding), dim=1)
        for i in range(len(mixing_weights)):
            # Share bottleneck layers across timesteps
            bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)]
            # Chunk the batch according the layer
            # Account for looping if there are multiple timesteps
            end = start + self.feature_dims[i % len(self.feature_dims)]
            feats = batch[:, start:end, :, :]
            start = end
            # Downsample the number of channels and weight the layer
            bottlenecked_feature = bottleneck_layer(feats)
            bottlenecked_feature = mixing_weights[i] * bottlenecked_feature
            if output_feature is None:
                output_feature = bottlenecked_feature
            else:
                output_feature += bottlenecked_feature

        if self.last_layer is not None:

            output_feature_after = self.last_layer(output_feature)
            if self.skip_connection:
                # skip connection
                output_feature = output_feature + output_feature_after
        return output_feature


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        y = x
        y = self.relu(self.bn1(self.conv1(y)))
        y = self.bn2(self.conv2(y))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
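A minimal sketch (not part of the commit) of projecting a DINOv2 feature map with AggregationNetwork, as app.py does with the checkpoint shipped in this commit. The random tensor is a stand-in for real descriptors; CPU is assumed here:

import torch
from model_utils.projection_network import AggregationNetwork

device = 'cpu'
net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
net.load_pretrained_weights(torch.load("ckpts/dino_spair_0300.pth", map_location=device))

features = torch.randn(1, 768, 30, 30)  # stand-in for a [1, C, 30, 30] DINOv2 feature map
with torch.no_grad():
    desc = net(features)                # projected descriptors, [1, 768, 30, 30]
    desc = desc / (desc.norm(dim=1, keepdim=True) + 1e-8)  # L2-normalize per location
print(desc.shape)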
model_utils/resnet.py
ADDED
@@ -0,0 +1,518 @@
import torch
from torch import nn
import torch.nn.functional as F
# import fvcore.nn.weight_init as weight_init
import numpy as np

"""
Functions for building the BottleneckBlock from Detectron2.
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
"""

def get_norm(norm, out_channels, num_norm_groups=32):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels),
        }[norm]
    return norm(out_channels)

def get_activation(activation):
    """
    Args:
        activation (str or callable): either one of relu, lrelu, prelu, leaky_relu,
            sigmoid, tanh, elu, selu, swish, mish; or a callable that takes a
            tensor and returns a tensor.
    Returns:
        nn.Module or None: the activation layer
    """
    if activation is None:
        return None
    if isinstance(activation, str):
        if len(activation) == 0:
            return None
        activation = {
            "relu": nn.ReLU,
            "lrelu": nn.LeakyReLU,
            "prelu": nn.PReLU,
            "leaky_relu": nn.LeakyReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
            "elu": nn.ELU,
            "selu": nn.SELU,
        }[activation]
    return activation()

# SCE crisscross + diags
class EfficientSpatialContextNet(nn.Module):
    def __init__(self, kernel_size=7, in_channels=768, out_channels=768, use_cuda=True):
        super(EfficientSpatialContextNet, self).__init__()
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        self.conv = torch.nn.Conv2d(
            in_channels + 4 * self.kernel_size,
            out_channels,
            1,
            bias=True,
            padding_mode="zeros",
        )

        if use_cuda:
            self.conv = self.conv.cuda()

    def forward(self, feature):
        b, c, h, w = feature.size()
        feature_normalized = F.normalize(feature, p=2, dim=1)
        feature_pad = F.pad(
            feature_normalized, (self.pad, self.pad, self.pad, self.pad), "constant", 0
        )
        output = torch.zeros(
            [4 * self.kernel_size, b, h, w],
            dtype=feature.dtype,
            requires_grad=feature.requires_grad,
        )
        if feature.is_cuda:
            output = output.cuda(feature.get_device())

        # left-top to right-bottom
        for i in range(self.kernel_size):
            c = i
            r = i
            output[i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)

        # col
        for i in range(self.kernel_size):
            c = self.kernel_size // 2
            r = i
            output[1 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)

        # right-top to left-bottom
        for i in range(self.kernel_size):
            c = (self.kernel_size - 1) - i
            r = i
            output[2 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)

        # row
        for i in range(self.kernel_size):
            c = i
            r = self.kernel_size // 2
            output[3 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)

        output = output.transpose(0, 1).contiguous()
        output = torch.cat((feature, output), 1)
        output = self.conv(output)
        # output = F.relu(output)

        return output

def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)

class Conv2d(nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function
        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.
    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.
        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

class BottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block used by ResNet-50, 101 and 152
    defined in :paper:`ResNet`. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1, and a projection shortcut if needed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="GN",
        stride_in_1x1=False,
        dilation=1,
        num_norm_groups=32,
        kernel_size=(1, 3, 1)
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels, num_norm_groups),
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=kernel_size[0],
            stride=stride_1x1,
            padding=(kernel_size[0] - 1) // 2,
            bias=False,
            norm=get_norm(norm, bottleneck_channels, num_norm_groups),
        )

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=kernel_size[1],
            stride=stride_3x3,
            padding=dilation * (kernel_size[1] - 1) // 2,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels, num_norm_groups),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=kernel_size[2],
            bias=False,
            norm=get_norm(norm, out_channels, num_norm_groups),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                c2_msra_fill(layer)

        # Zero-initialize the last normalization in each residual branch,
        # so that at the beginning, the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
        # "For BN layers, the learnable scaling coefficient γ is initialized
        # to be 1, except for each residual block's last BN
        # where γ is initialized to be 0."

        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out

class ResNet(nn.Module):
    """
    Implement :paper:`ResNet`.
    """

    def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
        """
        Args:
            stem (nn.Module): a stem module
            stages (list[list[CNNBlockBase]]): several (typically 4) stages,
                each contains multiple :class:`CNNBlockBase`.
            num_classes (None or int): if None, will not perform classification.
                Otherwise, will create a linear layer.
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "linear", or "res2" ...
                If None, will return the output of the last layer.
            freeze_at (int): The number of stages at the beginning to freeze.
                see :meth:`freeze` for detailed explanation.
        """
        super().__init__()
        self.stem = stem
        self.num_classes = num_classes

        current_stride = self.stem.stride
        self._out_feature_strides = {"stem": current_stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}

        self.stage_names, self.stages = [], []

        if out_features is not None:
            # Avoid keeping unused layers in this module. They consume extra memory
            # and may cause allreduce to fail
            num_stages = max(
                [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
            )
            stages = stages[:num_stages]
        for i, blocks in enumerate(stages):
            assert len(blocks) > 0, len(blocks)
            for block in blocks:
                assert isinstance(block, CNNBlockBase), block

            name = "res" + str(i + 2)
            stage = nn.Sequential(*blocks)

            self.add_module(name, stage)
            self.stage_names.append(name)
            self.stages.append(stage)

            self._out_feature_strides[name] = current_stride = int(
                current_stride * np.prod([k.stride for k in blocks])
            )
            self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
        self.stage_names = tuple(self.stage_names)  # Make it static for scripting

        if num_classes is not None:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.linear = nn.Linear(curr_channels, num_classes)

            # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
            # "The 1000-way fully-connected layer is initialized by
            # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
            nn.init.normal_(self.linear.weight, std=0.01)
            name = "linear"

        if out_features is None:
            out_features = [name]
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))
        self.freeze(freeze_at)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name, stage in zip(self.stage_names, self.stages):
            x = stage(x)
            if name in self._out_features:
                outputs[name] = x
        if self.num_classes is not None:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.linear(x)
            if "linear" in self._out_features:
                outputs["linear"] = x
        return outputs

    def freeze(self, freeze_at=0):
        """
        Freeze the first several stages of the ResNet. Commonly used in
        fine-tuning.
        Layers that produce the same feature map spatial size are defined as one
        "stage" by :paper:`FPN`.
        Args:
            freeze_at (int): number of stages to freeze.
                `1` means freezing the stem. `2` means freezing the stem and
                one residual stage, etc.
        Returns:
            nn.Module: this ResNet itself
        """
        if freeze_at >= 1:
            self.stem.freeze()
        for idx, stage in enumerate(self.stages, start=2):
            if freeze_at >= idx:
                for block in stage.children():
                    block.freeze()
        return self

    @staticmethod
    def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
        """
        Create a list of blocks of the same type that forms one ResNet stage.
        Args:
            block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
                stage. A module of this type must not change spatial resolution of inputs unless its
                stride != 1.
            num_blocks (int): number of blocks in this stage
            in_channels (int): input channels of the entire stage.
            out_channels (int): output channels of **every block** in the stage.
            kwargs: other arguments passed to the constructor of
                `block_class`. If the argument name is "xx_per_block", the
                argument is a list of values to be passed to each block in the
                stage. Otherwise, the same argument is passed to every block
                in the stage.
        Returns:
            list[CNNBlockBase]: a list of block module.
        Examples:
        ::
            stage = ResNet.make_stage(
                BottleneckBlock, 3, in_channels=16, out_channels=64,
                bottleneck_channels=16, num_groups=1,
                stride_per_block=[2, 1, 1],
                dilations_per_block=[1, 1, 2]
            )
        Usually, layers that produce the same feature map spatial size are defined as one
        "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
        all be 1.
        """
        blocks = []
        for i in range(num_blocks):
            curr_kwargs = {}
            for k, v in kwargs.items():
                if k.endswith("_per_block"):
                    assert len(v) == num_blocks, (
                        f"Argument '{k}' of make_stage should have the "
                        f"same length as num_blocks={num_blocks}."
                    )
                    newk = k[: -len("_per_block")]
                    assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
                    curr_kwargs[newk] = v[i]
                else:
                    curr_kwargs[k] = v

            blocks.append(
                block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
            )
            in_channels = out_channels
        return blocks

    @staticmethod
    def make_default_stages(depth, block_class=None, **kwargs):
        """
        Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
        If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
        instead for fine-grained customization.
        Args:
            depth (int): depth of ResNet
            block_class (type): the CNN block class. Has to accept
                `bottleneck_channels` argument for depth > 50.
                By default it is BasicBlock or BottleneckBlock, based on the
                depth.
            kwargs:
                other arguments to pass to `make_stage`. Should not contain
                stride and channels, as they are predefined for each depth.
        Returns:
            list[list[CNNBlockBase]]: modules in all stages; see arguments of
                :class:`ResNet.__init__`.
        """
        num_blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[depth]
        if block_class is None:
            block_class = BasicBlock if depth < 50 else BottleneckBlock
        if depth < 50:
            in_channels = [64, 64, 128, 256]
            out_channels = [64, 128, 256, 512]
        else:
            in_channels = [64, 256, 512, 1024]
            out_channels = [256, 512, 1024, 2048]
        ret = []
        for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
            if depth >= 50:
                kwargs["bottleneck_channels"] = o // 4
            ret.append(
                ResNet.make_stage(
                    block_class=block_class,
                    num_blocks=n,
                    stride_per_block=[s] + [1] * (n - 1),
                    in_channels=i,
                    out_channels=o,
                    **kwargs,
                )
            )
        return ret
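A minimal sketch (not part of the commit) of ResNet.make_stage as AggregationNetwork above uses it: a single GroupNorm BottleneckBlock that maps a 768-channel feature map back to 768 channels through a 192-channel bottleneck:

import torch
from torch import nn
from model_utils.resnet import ResNet, BottleneckBlock

stage = nn.Sequential(
    *ResNet.make_stage(
        BottleneckBlock,
        num_blocks=1,
        in_channels=768,
        bottleneck_channels=192,
        out_channels=768,
        norm="GN",
        num_norm_groups=32,
        kernel_size=(1, 3, 1),
    )
)
out = stage(torch.randn(1, 768, 30, 30))
print(out.shape)  # torch.Size([1, 768, 30, 30]); spatial size is preserved (stride 1)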
requirements.txt
ADDED
@@ -0,0 +1,6 @@
numpy
torch
torchvision
pillow
gradio
matplotlib