odunkel committed
Commit 79cc514 · verified · 1 Parent(s): f3b59a6

Upload 9 files

app.py ADDED
@@ -0,0 +1,243 @@
1
+
2
+ import os
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from PIL import Image, ImageDraw
8
+ from matplotlib import cm
9
+
10
+ from model_utils.extractor_dino import ViTExtractor
11
+ from model_utils.projection_network import AggregationNetwork, DummyAggregationNetwork
12
+
13
+ def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_filter='lanczos'):
14
+ filt = Image.Resampling.LANCZOS if sampling_filter == 'lanczos' else Image.Resampling.NEAREST
15
+ original_width, original_height = img.size
16
+ original_channels = len(img.getbands())
17
+ if not edge:
18
+ canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8)
19
+ if original_channels == 1:
20
+ canvas = np.zeros([target_res, target_res], dtype=np.uint8)
21
+ if original_height <= original_width:
22
+ if resize:
23
+ img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
24
+ width, height = img.size
25
+ img = np.asarray(img)
26
+ canvas[(width - height) // 2: (width + height) // 2] = img
27
+ else:
28
+ if resize:
29
+ img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
30
+ width, height = img.size
31
+ img = np.asarray(img)
32
+ canvas[:, (height - width) // 2: (height + width) // 2] = img
33
+ else:
34
+ if original_height <= original_width:
35
+ if resize:
36
+ img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
37
+ width, height = img.size
38
+ img = np.asarray(img)
39
+ top_pad = (target_res - height) // 2
40
+ bottom_pad = target_res - height - top_pad
41
+ img = np.pad(img, pad_width=[(top_pad, bottom_pad), (0, 0), (0, 0)], mode='edge')
42
+ else:
43
+ if resize:
44
+ img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
45
+ width, height = img.size
46
+ img = np.asarray(img)
47
+ left_pad = (target_res - width) // 2
48
+ right_pad = target_res - width - left_pad
49
+ img = np.pad(img, pad_width=[(0, 0), (left_pad, right_pad), (0, 0)], mode='edge')
50
+ canvas = img
51
+ if to_pil:
52
+ canvas = Image.fromarray(canvas)
53
+ return canvas
54
+
55
+ # ─── Configuration ───────────────────────────────────────────────
56
+ num_patches = 30
57
+ target_res = num_patches * 14
58
+ ckpt_file = "ckpts/dino_spair_0300.pth"
59
+
60
+ # ─── Model setup ─────────────────────────────────────────────────
61
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
62
+ aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
63
+ aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
64
+ aggre_net_dummy = DummyAggregationNetwork()
65
+ extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
66
+
67
+ # ─── Feature extraction ──────────────────────────────────────────
68
+ def get_processed_features_dino(num_patches, img, use_dummy):
69
+ batch = extractor_vit.preprocess_pil(img)
70
+ features_dino = extractor_vit.extract_descriptors(batch.to(device), layer=11, facet='token') \
71
+ .permute(0,1,3,2) \
72
+ .reshape(1, -1, num_patches, num_patches)
73
+ # Project + normalize
74
+ with torch.no_grad():
75
+ if use_dummy == "DINOv2":
76
+ desc = aggre_net_dummy(features_dino)
77
+ else:
78
+ desc = aggre_net(features_dino)
79
+ norms = torch.linalg.norm(desc, dim=1, keepdim=True)
80
+ desc = desc / (norms + 1e-8)
81
+ return desc # shape [1, C, num_patches, num_patches]
82
+
83
+ # ─── Similarity computation ───────────────────────────────────────
84
+ def get_sim(
85
+ coord: tuple[int,int],
86
+ feat1: torch.Tensor,
87
+ feat2: torch.Tensor,
88
+ img_size: int = target_res
89
+ ) -> np.ndarray:
90
+ """
91
+ Upsamples the DINO features to `img_size`, then computes cosine‐similarity
92
+ between the feature at `coord` in source and every spatial location in target.
93
+ """
94
+ y, x = coord # row, col
95
+
96
+ # Upsample both feature maps to [1, C, img_size, img_size]
97
+ upsampler = nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=False)
98
+ src_ft = upsampler(feat1) # [1, C, img_size, img_size]
99
+ trg_ft = upsampler(feat2)
100
+
101
+ # Extract the C‐dim vector at the clicked location
102
+ C = src_ft.size(1)
103
+ src_vec = src_ft[0, :, y, x].view(1, C, 1, 1) # [1, C, 1, 1]
104
+
105
+ # Cosine similarity along channel‐dim
106
+ cos = nn.CosineSimilarity(dim=1)
107
+ cos_map = cos(src_vec, trg_ft)[0] # [img_size, img_size]
108
+ return cos_map.cpu().numpy()
109
+
110
+ # ─── Drawing helper ───────────────────────────────────────────────
111
+ def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255,0,0)) -> np.ndarray:
112
+ pil = Image.fromarray(img_arr)
113
+ draw = ImageDraw.Draw(pil)
114
+ r = size // 2
115
+ draw.ellipse((x-r, y-r, x+r, y+r), fill=color, outline=color)
116
+ return np.array(pil)
117
+
118
+ # ─── Feature‐updating callback ───────────────────────────────────
119
+ def update_features(
120
+ img: Image,
121
+ num_patches,
122
+ use_dummy
123
+ ):
124
+ """
125
+ Given a PIL image, returns:
126
+ 1) the resized PIL image (for display), 2) its descriptor tensor (stored in a gr.State),
127
+ and 3) a clean copy of the resized image, kept for re-drawing overlays later.
128
+ """
129
+ if img is None:
130
+ return None, None, None
131
+ img = resize(img, target_res=target_res, resize=True, to_pil=True)
132
+ feat = get_processed_features_dino(num_patches, img=img, use_dummy=use_dummy)
133
+ return img, feat.cpu(), Image.fromarray(np.array(img))
134
+
135
+ # ─── Click handler ───────────────────────────────────────────────
136
+ def on_select(
137
+ source_pil: Image,
138
+ target_pil: Image,
139
+ feat1: torch.Tensor,
140
+ feat2: torch.Tensor,
141
+ alpha: float,
142
+ scatter_size: int,
143
+ or_tgt_img: Image,
144
+ or_src_img: Image,
145
+ sel: gr.SelectData
146
+ ):
147
+ # Convert to numpy arrays
148
+ src_arr = np.array(or_src_img)
149
+ tgt_arr = np.array(or_tgt_img)
150
+
151
+ # Gradio returns the click coordinates as (x, y) = (col, row)
152
+ x, y = sel.index
153
+
154
+ src_marked = draw_point(src_arr, x, y, scatter_size)
155
+
156
+ # Compute similarity map (get_sim expects (row, col))
157
+ sim_map = get_sim((y, x), feat1, feat2, img_size=target_res)
158
+
159
+ mn, mx = sim_map.min(), sim_map.max()
160
+ sim_norm = (sim_map - mn) / ((mx - mn) + 1e-12)
161
+
162
+ # Build RGBA heatmap
163
+ heat = cm.viridis(sim_norm) # H×W×4
164
+ heat[..., 3] = sim_norm * alpha # alpha channel
165
+
166
+ # Composite over fresh target
167
+ tgt_f = tgt_arr.astype(np.float32) / 255.0
168
+ comp = heat[..., :3] * heat[..., 3:4] + tgt_f * (1 - heat[..., 3:4])
169
+ overlay = (comp * 255).astype(np.uint8)
170
+
171
+ # Draw a red dot at the best match
172
+ my, mx_ = np.unravel_index(sim_map.argmax(), sim_map.shape)
173
+ overlay_marked = draw_point(overlay, mx_, my, scatter_size)
174
+
175
+ return src_marked, overlay_marked
176
+
177
+ def reload_img(
178
+ or_src_img: Image,
179
+ or_tgt_img: Image,
180
+ ):
181
+ return or_src_img, or_tgt_img
182
+
183
+
184
+
185
+ # ─── Build Gradio UI ──────────────────────────────────────────────
186
+ with gr.Blocks() as demo:
187
+ # Hidden states to hold features
188
+ feat1_state = gr.State()
189
+ feat2_state = gr.State()
190
+ or_tgt_img = gr.State()
191
+ or_src_img = gr.State()
192
+
193
+ # Introduction text box
194
+ intro_text = gr.Markdown("""
195
+ ## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
196
+ [Project Page](https://example.com) | [GitHub Repository](https://github.com/example/repo)
197
+
198
+ Welcome to the DIY-SC demo!
199
+ Upload two images and select a keypoint in the source image. This demo will compute and visualize the feature similarity map and a corresponding point in the target image.
200
+ You can choose between DIY-SC (refined DINOv2 features) and plain DINOv2 features.
201
+ """)
202
+
203
+ # Image upload / display components
204
+ with gr.Row():
205
+ src = gr.Image(interactive=True, type="pil", label="Source Image")
206
+ tgt = gr.Image(interactive=True, type="pil", label="Target Image")
207
+
208
+ # Controls
209
+ alpha = gr.State(0.7)
210
+ scatter = gr.State(10)
211
+ use_dummy = gr.Radio(["DIY-SC", "DINOv2"], value="DIY-SC", label="Feature Extractor")
212
+
213
+ src.input(
214
+ fn=update_features,
215
+ inputs=[src, gr.State(num_patches), use_dummy],
216
+ outputs=[src, feat1_state, or_src_img]
217
+ )
218
+
219
+ tgt.input(
220
+ fn=update_features,
221
+ inputs=[tgt, gr.State(num_patches), use_dummy],
221
+ outputs=[tgt, feat2_state, or_tgt_img]
223
+ )
224
+
225
+ use_dummy.change(
226
+ fn=update_features,
227
+ inputs=[or_src_img, gr.State(num_patches), use_dummy],
228
+ outputs=[src, feat1_state, or_src_img]
229
+ )
230
+
231
+ use_dummy.change(
232
+ fn=update_features,
233
+ inputs=[or_tgt_img, gr.State(num_patches), use_dummy],
234
+ outputs=[tgt, feat2_state, or_tgt_img]
235
+ )
236
+
237
+ src.select(
238
+ fn=on_select,
239
+ inputs=[src, tgt, feat1_state, feat2_state, alpha, scatter, or_tgt_img, or_src_img],
239
+ outputs=[src, tgt]
241
+ )
242
+
243
+ demo.launch(share=True)
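
The pieces above can also be driven without the UI. Below is a minimal headless sketch, assuming hypothetical local files source.jpg and target.jpg and that the helpers are imported from app.py (e.g. from app import resize, get_processed_features_dino, get_sim, num_patches, target_res); note that app.py calls demo.launch on import, so in practice one would guard that call or copy the helpers.

import numpy as np
from PIL import Image

# hypothetical input files; any two RGB images work
src = resize(Image.open("source.jpg").convert("RGB"), target_res, resize=True, to_pil=True)
tgt = resize(Image.open("target.jpg").convert("RGB"), target_res, resize=True, to_pil=True)
feat_src = get_processed_features_dino(num_patches, img=src, use_dummy="DIY-SC")
feat_tgt = get_processed_features_dino(num_patches, img=tgt, use_dummy="DIY-SC")
# similarity of the source pixel at (row=100, col=150) against every target location
sim = get_sim((100, 150), feat_src, feat_tgt, img_size=target_res)
row, col = np.unravel_index(sim.argmax(), sim.shape)
print("best match in target:", row, col)
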
ckpts/dino_spair_0300.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42e9c8a4d27041af7bb8dc5240ea3e83d47e013b72541dc8c5829368f247e705
3
+ size 2519767
model_utils/__pycache__/extractor_dino.cpython-310.pyc ADDED
Binary file (14.5 kB).
 
model_utils/__pycache__/projection_network.cpython-310.pyc ADDED
Binary file (5.36 kB).
 
model_utils/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (15.4 kB).
 
model_utils/extractor_dino.py ADDED
@@ -0,0 +1,356 @@
1
+ import torch
2
+ from torch import nn
3
+ from torchvision import transforms
4
+ import torch.nn.modules.utils as nn_utils
5
+ import math
6
+ import types
7
+ from pathlib import Path
8
+ from typing import Union, List, Tuple
9
+ from PIL import Image
10
+
11
+
12
+ class ViTExtractor:
13
+ """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.
14
+ We use the following notation in the documentation of the module's methods:
15
+ B - batch size
16
+ h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW
17
+ p - patch size of the ViT. either 8, 14 (DINOv2), or 16.
18
+ t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width
19
+ of the input image.
20
+ d - the embedding dimension in the ViT.
21
+ """
22
+
23
+ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'):
24
+ """
25
+ :param model_type: A string specifying the type of model to extract from.
26
+ [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
27
+ vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]
28
+ :param stride: stride of first convolution layer. small stride -> higher resolution.
29
+ :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
30
+ should be compatible with model_type.
31
+ """
32
+ self.model_type = model_type
33
+ self.device = device
34
+ if model is not None:
35
+ self.model = model
36
+ else:
37
+ self.model = ViTExtractor.create_model(model_type)
38
+
39
+ self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
40
+ self.model.eval()
41
+ self.model.to(self.device)
42
+ self.p = self.model.patch_embed.patch_size
43
+ if type(self.p)==tuple:
44
+ self.p = self.p[0]
45
+ self.stride = self.model.patch_embed.proj.stride
46
+
47
+ self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
48
+ self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5)
49
+
50
+ self._feats = []
51
+ self.hook_handlers = []
52
+ self.load_size = None
53
+ self.num_patches = None
54
+
55
+ @staticmethod
56
+ def create_model(model_type: str) -> nn.Module:
57
+ """
58
+ :param model_type: a string specifying which model to load. [dino_vits8 | dino_vits16 | dino_vitb8 |
59
+ dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 |
60
+ vit_base_patch16_224]
61
+ :return: the model
62
+ """
63
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
64
+ if 'v2' in model_type:
65
+ model = torch.hub.load('facebookresearch/dinov2', model_type)
66
+ elif 'dino' in model_type:
67
+ model = torch.hub.load('facebookresearch/dino:main', model_type)
68
+ elif 'ibot' in model_type:
69
+ model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
70
+ temp_state_dict = torch.load("ibot/checkpoint_teacher.pth", map_location="cpu")
71
+ temp_state_dict = temp_state_dict["state_dict"]
72
+ # remove `module.` prefix
73
+ temp_state_dict = {k.replace("module.", ""): v for k, v in temp_state_dict.items()}
74
+ # remove `backbone.` prefix induced by multicrop wrapper
75
+ temp_state_dict = {k.replace("backbone.", ""): v for k, v in temp_state_dict.items()}
76
+ msg = model.load_state_dict(temp_state_dict, strict=False)
77
+ print(msg)
78
+ else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
79
+ import timm
80
+ temp_model = timm.create_model(model_type, pretrained=True)
81
+ model_type_dict = {
82
+ 'vit_small_patch16_224': 'dino_vits16',
83
+ 'vit_small_patch8_224': 'dino_vits8',
84
+ 'vit_base_patch16_224': 'dino_vitb16',
85
+ 'vit_base_patch8_224': 'dino_vitb8'
86
+ }
87
+ model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type])
88
+ temp_state_dict = temp_model.state_dict()
89
+ del temp_state_dict['head.weight']
90
+ del temp_state_dict['head.bias']
91
+ model.load_state_dict(temp_state_dict)
92
+ return model
93
+
94
+ @staticmethod
95
+ def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]):
96
+ """
97
+ Creates a method for position encoding interpolation.
98
+ :param patch_size: patch size of the model.
99
+ :param stride_hw: A tuple containing the new height and width stride respectively.
100
+ :return: the interpolation method
101
+ """
102
+ def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
103
+ npatch = x.shape[1] - 1
104
+ N = self.pos_embed.shape[1] - 1
105
+ if npatch == N and w == h:
106
+ return self.pos_embed
107
+ class_pos_embed = self.pos_embed[:, 0]
108
+ patch_pos_embed = self.pos_embed[:, 1:]
109
+ dim = x.shape[-1]
110
+ # compute number of tokens taking stride into account
111
+ w0 = 1 + (w - patch_size) // stride_hw[1]
112
+ h0 = 1 + (h - patch_size) // stride_hw[0]
113
+ assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and
114
+ stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}"""
115
+ # we add a small number to avoid floating point error in the interpolation
116
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
117
+ w0, h0 = w0 + 0.1, h0 + 0.1
118
+ patch_pos_embed = nn.functional.interpolate(
119
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
120
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
121
+ mode='bicubic',
122
+ align_corners=False, recompute_scale_factor=False
123
+ )
124
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
125
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
126
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
127
+
128
+ return interpolate_pos_encoding
129
+
130
+ @staticmethod
131
+ def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
132
+ """
133
+ change resolution of model output by changing the stride of the patch extraction.
134
+ :param model: the model to change resolution for.
135
+ :param stride: the new stride parameter.
136
+ :return: the adjusted model
137
+ """
138
+ patch_size = model.patch_embed.patch_size
139
+ if type(patch_size) == tuple:
140
+ patch_size = patch_size[0]
141
+ if stride == patch_size: # nothing to do
142
+ return model
143
+
144
+ stride = nn_utils._pair(stride)
145
+ assert all([(patch_size // s_) * s_ == patch_size for s_ in
146
+ stride]), f'stride {stride} should divide patch_size {patch_size}'
147
+
148
+ # fix the stride
149
+ model.patch_embed.proj.stride = stride
150
+ # fix the positional encoding code
151
+ model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model)
152
+ return model
153
+
154
+ def preprocess(self, image_path: Union[str, Path],
155
+ load_size: Union[int, Tuple[int, int]] = None, patch_size: int = 14) -> Tuple[torch.Tensor, Image.Image]:
156
+ """
157
+ Preprocesses an image before extraction.
158
+ :param image_path: path to image to be extracted.
159
+ :param load_size: optional. Size to resize image before the rest of preprocessing.
160
+ :return: a tuple containing:
161
+ (1) the preprocessed image as a tensor of shape BxCxHxW, ready to be fed into the model.
162
+ (2) the pil image in relevant dimensions
163
+ """
164
+ def divisible_by_num(num, dim):
165
+ return num * (dim // num)
166
+ pil_image = Image.open(image_path).convert('RGB')
167
+ if load_size is not None:
168
+ pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)
169
+
170
+ width, height = pil_image.size
171
+ new_width = divisible_by_num(patch_size, width)
172
+ new_height = divisible_by_num(patch_size, height)
173
+ pil_image = pil_image.resize((new_width, new_height), resample=Image.LANCZOS)
174
+
175
+ prep = transforms.Compose([
176
+ transforms.ToTensor(),
177
+ transforms.Normalize(mean=self.mean, std=self.std)
178
+ ])
179
+ prep_img = prep(pil_image)[None, ...]
180
+ return prep_img, pil_image
181
+
182
+ def preprocess_pil(self, pil_image):
183
+ """
184
+ Preprocesses an image before extraction.
185
+ :param pil_image: the PIL image to preprocess.
186
+ :return: the preprocessed image as a tensor of shape BxCxHxW,
187
+ ready to be fed into the model.
188
+ Unlike preprocess, no file loading or resizing is done here;
189
+ the caller is expected to resize the image beforehand.
190
+ """
191
+ prep = transforms.Compose([
192
+ transforms.ToTensor(),
193
+ transforms.Normalize(mean=self.mean, std=self.std)
194
+ ])
195
+ prep_img = prep(pil_image)[None, ...]
196
+ return prep_img
197
+
198
+ def _get_hook(self, facet: str):
199
+ """
200
+ generate a hook method for a specific block and facet.
201
+ """
202
+ if facet in ['attn', 'token']:
203
+ def _hook(model, input, output):
204
+ self._feats.append(output)
205
+ return _hook
206
+
207
+ if facet == 'query':
208
+ facet_idx = 0
209
+ elif facet == 'key':
210
+ facet_idx = 1
211
+ elif facet == 'value':
212
+ facet_idx = 2
213
+ else:
214
+ raise TypeError(f"{facet} is not a supported facet.")
215
+
216
+ def _inner_hook(module, input, output):
217
+ input = input[0]
218
+ B, N, C = input.shape
219
+ qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
220
+ self._feats.append(qkv[facet_idx]) #Bxhxtxd
221
+ return _inner_hook
222
+
223
+ def _register_hooks(self, layers: List[int], facet: str) -> None:
224
+ """
225
+ register hook to extract features.
226
+ :param layers: layers from which to extract features.
227
+ :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
228
+ """
229
+ for block_idx, block in enumerate(self.model.blocks):
230
+ if block_idx in layers:
231
+ if facet == 'token':
232
+ self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
233
+ elif facet == 'attn':
234
+ self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
235
+ elif facet in ['key', 'query', 'value']:
236
+ self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
237
+ else:
238
+ raise TypeError(f"{facet} is not a supported facet.")
239
+
240
+ def _unregister_hooks(self) -> None:
241
+ """
242
+ unregisters the hooks. should be called after feature extraction.
243
+ """
244
+ for handle in self.hook_handlers:
245
+ handle.remove()
246
+ self.hook_handlers = []
247
+
248
+ def _extract_features(self, batch: torch.Tensor, layers: List[int] = [11], facet: str = 'key') -> List[torch.Tensor]:
249
+ """
250
+ extract features from the model
251
+ :param batch: batch to extract features for. Has shape BxCxHxW.
252
+ :param layers: layers to extract from. Each is a number between 0 and 11.
253
+ :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
254
+ :return : tensor of features.
255
+ if facet is 'key' | 'query' | 'value' has shape Bxhxtxd
256
+ if facet is 'attn' has shape Bxhxtxt
257
+ if facet is 'token' has shape Bxtxd
258
+ """
259
+ B, C, H, W = batch.shape
260
+ self._feats = []
261
+ self._register_hooks(layers, facet)
262
+ _ = self.model(batch)
263
+ self._unregister_hooks()
264
+ self.load_size = (H, W)
265
+ self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
266
+ return self._feats
267
+
268
+ def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
269
+ """
270
+ create a log-binned descriptor.
271
+ :param x: tensor of features. Has shape Bxhxtxd. [1,6,3410,64]
272
+ :param hierarchy: how many bin hierarchies to use.
273
+ """
274
+ B = x.shape[0]
275
+ num_bins = 1 + 8 * hierarchy
276
+
277
+ bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh) [1,3410,384]
278
+ bin_x = bin_x.permute(0, 2, 1)
279
+ bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1])
280
+ # Bx(dxh)xnum_patches[0]xnum_patches[1]
281
+ sub_desc_dim = bin_x.shape[1]
282
+
283
+ avg_pools = []
284
+ # compute bins of all sizes for all spatial locations.
285
+ for k in range(0, hierarchy):
286
+ # avg pooling with kernel 3**kx3**k
287
+ win_size = 3 ** k
288
+ avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
289
+ avg_pools.append(avg_pool(bin_x))
290
+
291
+ bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device)
292
+ for y in range(self.num_patches[0]):
293
+ for x in range(self.num_patches[1]):
294
+ part_idx = 0
295
+ # fill all bins for a spatial location (y, x)
296
+ for k in range(0, hierarchy):
297
+ kernel_size = 3 ** k
298
+ for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
299
+ for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
300
+ if i == y and j == x and k != 0:
301
+ continue
302
+ if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]:
303
+ bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
304
+ :, :, i, j]
305
+ else: # handle padding in a more delicate way than zero padding
306
+ temp_i = max(0, min(i, self.num_patches[0] - 1))
307
+ temp_j = max(0, min(j, self.num_patches[1] - 1))
308
+ bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
309
+ :, :, temp_i,
310
+ temp_j]
311
+ part_idx += 1
312
+ bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
313
+ # Bx1x(t-1)x(dxh)
314
+ return bin_x #[1,1,3410,6528]
315
+
316
+ def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key',
317
+ bin: bool = False, include_cls: bool = False) -> torch.Tensor:
318
+ """
319
+ extract descriptors from the model
320
+ :param batch: batch to extract descriptors for. Has shape BxCxHxW.
321
+ :param layer: layer to extract. A number between 0 and 11.
322
+ :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
323
+ :param bin: apply log binning to the descriptor. default is False.
324
+ :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
325
+ """
326
+ assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
327
+ choose from ['key' | 'query' | 'value' | 'token'] """
328
+ self._extract_features(batch, [layer], facet)
329
+ x = self._feats[0]
330
+ if facet == 'token':
331
+ x.unsqueeze_(dim=1) #Bx1xtxd
332
+ if not include_cls:
333
+ x = x[:, :, 1:, :] # remove cls token
334
+ else:
335
+ assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
336
+ if not bin:
337
+ desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh)
338
+ else:
339
+ desc = self._log_bin(x)
340
+ return desc
341
+
342
+ def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
343
+ """
344
+ extract saliency maps. The saliency maps are extracted by averaging several attention heads from the last layer
345
+ in of the CLS token. All values are then normalized to range between 0 and 1.
346
+ :param batch: batch to extract saliency maps for. Has shape BxCxHxW.
347
+ :return: a tensor of saliency maps. has shape Bxt-1
348
+ """
349
+ assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type."
350
+ self._extract_features(batch, [11], 'attn')
351
+ head_idxs = [0, 2, 4, 5]
352
+ curr_feats = self._feats[0] #Bxhxtxt
353
+ cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) #Bx(t-1)
354
+ temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0]
355
+ cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1]
356
+ return cls_attn_maps
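
For reference, a hedged usage sketch of ViTExtractor as app.py uses it; the image path and the 420-pixel size are illustrative assumptions (420 = 30 patches x 14-pixel patch size), not values from this commit.

import torch
from PIL import Image
from model_utils.extractor_dino import ViTExtractor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
extractor = ViTExtractor('dinov2_vitb14', stride=14, device=device)
# hypothetical input image, resized so it divides evenly into 14-pixel patches
img = Image.open("example.jpg").convert("RGB").resize((420, 420))
batch = extractor.preprocess_pil(img).to(device)
with torch.no_grad():
    desc = extractor.extract_descriptors(batch, layer=11, facet='token')
print(desc.shape)  # expected [1, 1, 900, 768]: 30x30 patch tokens, 768 channels each
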
model_utils/projection_network.py ADDED
@@ -0,0 +1,167 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from model_utils.resnet import ResNet, BottleneckBlock
5
+ import torch.nn.functional as F
6
+
7
+ class DummyAggregationNetwork(nn.Module): # for testing, return the input
8
+ def __init__(self):
9
+ super(DummyAggregationNetwork, self).__init__()
10
+ # dummy parameter
11
+ self.dummy = nn.Parameter(torch.ones([]))
12
+ def forward(self, batch, pose=None):
13
+ return batch * self.dummy
14
+
15
+ class AggregationNetwork(nn.Module):
16
+ """
17
+ Module for aggregating feature maps across time and space.
18
+ Design inspired by the Feature Extractor from ODISE (Xu et al., CVPR 2023).
19
+ https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py
20
+ """
21
+ def __init__(
22
+ self,
23
+ device,
24
+ feature_dims=[640,1280,1280,768],
25
+ projection_dim=384,
26
+ num_norm_groups=32,
27
+ save_timestep=[1],
28
+ kernel_size = [1,3,1],
29
+ contrastive_temp = 10,
30
+ feat_map_dropout=0.0,
31
+ num_blocks=None,
32
+ bottleneck_channels=None
33
+ ):
34
+ super().__init__()
35
+ self.skip_connection = True
36
+ self.feat_map_dropout = feat_map_dropout
37
+ self.azimuth_embedding = None
38
+ self.pos_embedding = None
39
+ self.bottleneck_layers = nn.ModuleList()
40
+ self.feature_dims = feature_dims
41
+ self.num_blocks = num_blocks if num_blocks is not None else 1
42
+ self.bottleneck_channels = bottleneck_channels if bottleneck_channels is not None else projection_dim//4
43
+ # For CLIP symmetric cross entropy loss during training
44
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
45
+ self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
46
+ self.device = device
47
+ self.save_timestep = save_timestep
48
+
49
+ self.mixing_weights_names = []
50
+ for l, feature_dim in enumerate(self.feature_dims):
51
+ bottleneck_layer = nn.Sequential(
52
+ *ResNet.make_stage(
53
+ BottleneckBlock,
54
+ num_blocks=self.num_blocks,
55
+ in_channels=feature_dim,
56
+ bottleneck_channels=self.bottleneck_channels,
57
+ out_channels=projection_dim,
58
+ norm="GN",
59
+ num_norm_groups=num_norm_groups,
60
+ kernel_size=kernel_size
61
+ )
62
+ )
63
+ self.bottleneck_layers.append(bottleneck_layer)
64
+ for t in save_timestep:
65
+ # 1-index the layer name following prior work
66
+ self.mixing_weights_names.append(f"timestep-{save_timestep}_layer-{l+1}")
67
+ self.last_layer = None
68
+ self.bottleneck_layers = self.bottleneck_layers.to(device)
69
+ mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
70
+ self.mixing_weights = nn.Parameter(mixing_weights.to(device))
71
+ # count number of parameters
72
+ num_params = 0
73
+ for param in self.parameters():
74
+ num_params += param.numel()
75
+ print(f"AggregationNetwork has {num_params} parameters.")
76
+
77
+ def load_pretrained_weights(self, pretrained_dict):
78
+ custom_dict = self.state_dict()
79
+
80
+ # Handle size mismatch
81
+ if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict['mixing_weights'].shape != pretrained_dict['mixing_weights'].shape:
82
+ # Keep the first four weights from the pretrained model, and randomly initialize the fifth weight
83
+ custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
84
+ custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4])
85
+ else:
86
+ custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
87
+
88
+ # Load the weights that do match
89
+ matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'}
90
+ custom_dict.update(matching_keys)
91
+
92
+ # Now load the updated state_dict
93
+ self.load_state_dict(custom_dict, strict=False)
94
+
95
+ def forward(self, batch, pose=None):
96
+ """
97
+ Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.
98
+ """
99
+ if self.feat_map_dropout > 0 and self.training:
100
+ batch = F.dropout(batch, p=self.feat_map_dropout)
101
+
102
+ output_feature = None
103
+ start = 0
104
+ mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0)
105
+ if self.pos_embedding is not None: #position embedding
106
+ batch = torch.cat((batch, self.pos_embedding), dim=1)
107
+ for i in range(len(mixing_weights)):
108
+ # Share bottleneck layers across timesteps
109
+ bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)]
110
+ # Chunk the batch according to the layer
111
+ # Account for looping if there are multiple timesteps
112
+ end = start + self.feature_dims[i % len(self.feature_dims)]
113
+ feats = batch[:, start:end, :, :]
114
+ start = end
115
+ # Downsample the number of channels and weight the layer
116
+ bottlenecked_feature = bottleneck_layer(feats)
117
+ bottlenecked_feature = mixing_weights[i] * bottlenecked_feature
118
+ if output_feature is None:
119
+ output_feature = bottlenecked_feature
120
+ else:
121
+ output_feature += bottlenecked_feature
122
+
123
+ if self.last_layer is not None:
124
+
125
+ output_feature_after = self.last_layer(output_feature)
126
+ if self.skip_connection:
127
+ # skip connection
128
+ output_feature = output_feature + output_feature_after
129
+ return output_feature
130
+
131
+
132
+ def conv1x1(in_planes, out_planes, stride=1):
133
+ """1x1 convolution without padding"""
134
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
135
+
136
+
137
+ def conv3x3(in_planes, out_planes, stride=1):
138
+ """3x3 convolution with padding"""
139
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
140
+
141
+
142
+ class BasicBlock(nn.Module):
143
+ def __init__(self, in_planes, planes, stride=1):
144
+ super().__init__()
145
+ self.conv1 = conv3x3(in_planes, planes, stride)
146
+ self.conv2 = conv3x3(planes, planes)
147
+ self.bn1 = nn.BatchNorm2d(planes)
148
+ self.bn2 = nn.BatchNorm2d(planes)
149
+ self.relu = nn.ReLU(inplace=True)
150
+
151
+ if stride == 1:
152
+ self.downsample = None
153
+ else:
154
+ self.downsample = nn.Sequential(
155
+ conv1x1(in_planes, planes, stride=stride),
156
+ nn.BatchNorm2d(planes)
157
+ )
158
+
159
+ def forward(self, x):
160
+ y = x
161
+ y = self.relu(self.bn1(self.conv1(y)))
162
+ y = self.bn2(self.conv2(y))
163
+
164
+ if self.downsample is not None:
165
+ x = self.downsample(x)
166
+
167
+ return self.relu(x+y)
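
As a sanity check, a minimal sketch of how app.py drives AggregationNetwork; the random tensor is a stand-in (an assumption for illustration) for a 30x30 DINOv2 token map, and the shapes follow the configuration used in app.py.

import torch
from model_utils.projection_network import AggregationNetwork

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
net.load_pretrained_weights(torch.load("ckpts/dino_spair_0300.pth", map_location=device))
feats = torch.randn(1, 768, 30, 30, device=device)  # stand-in for DINOv2 token features
with torch.no_grad():
    out = net(feats)
print(out.shape)  # torch.Size([1, 768, 30, 30]); app.py then L2-normalizes each spatial descriptor
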
model_utils/resnet.py ADDED
@@ -0,0 +1,518 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ # import fvcore.nn.weight_init as weight_init
5
+ import numpy as np
6
+
7
+ """
8
+ Functions for building the BottleneckBlock from Detectron2.
9
+ # https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
10
+ """
11
+
12
+ def get_norm(norm, out_channels, num_norm_groups=32):
13
+ """
14
+ Args:
15
+ norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
16
+ or a callable that takes a channel number and returns
17
+ the normalization layer as a nn.Module.
18
+ Returns:
19
+ nn.Module or None: the normalization layer
20
+ """
21
+ if norm is None:
22
+ return None
23
+ if isinstance(norm, str):
24
+ if len(norm) == 0:
25
+ return None
26
+ norm = {
27
+ "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels),
28
+ }[norm]
29
+ return norm(out_channels)
30
+
31
+ def get_activation(activation):
32
+ """
33
+ Args:
34
+ activation (str or callable): either one of relu, lrelu, prelu, leaky_relu,
35
+ sigmoid, tanh, elu, selu, swish, mish; or a callable that takes a
36
+ tensor and returns a tensor.
37
+ Returns:
38
+ nn.Module or None: the activation layer
39
+ """
40
+ if activation is None:
41
+ return None
42
+ if isinstance(activation, str):
43
+ if len(activation) == 0:
44
+ return None
45
+ activation = {
46
+ "relu": nn.ReLU,
47
+ "lrelu": nn.LeakyReLU,
48
+ "prelu": nn.PReLU,
49
+ "leaky_relu": nn.LeakyReLU,
50
+ "sigmoid": nn.Sigmoid,
51
+ "tanh": nn.Tanh,
52
+ "elu": nn.ELU,
53
+ "selu": nn.SELU,
54
+ }[activation]
55
+ return activation()
56
+
57
+ # SCE crisscross + diags
58
+ class EfficientSpatialContextNet(nn.Module):
59
+ def __init__(self, kernel_size=7, in_channels=768, out_channels=768, use_cuda=True):
60
+ super(EfficientSpatialContextNet, self).__init__()
61
+ self.kernel_size = kernel_size
62
+ self.pad = kernel_size // 2
63
+ self.conv = torch.nn.Conv2d(
64
+ in_channels + 4*self.kernel_size,
65
+ out_channels,
66
+ 1,
67
+ bias=True,
68
+ padding_mode="zeros",
69
+ )
70
+
71
+ if use_cuda:
72
+ self.conv = self.conv.cuda()
73
+
74
+ def forward(self, feature):
75
+ b, c, h, w = feature.size()
76
+ feature_normalized = F.normalize(feature, p=2, dim=1)
77
+ feature_pad = F.pad(
78
+ feature_normalized, (self.pad, self.pad, self.pad, self.pad), "constant", 0
79
+ )
80
+ output = torch.zeros(
81
+ [4*self.kernel_size, b, h, w],
82
+ dtype=feature.dtype,
83
+ requires_grad=feature.requires_grad,
84
+ )
85
+ if feature.is_cuda:
86
+ output = output.cuda(feature.get_device())
87
+
88
+ # left-top to right-bottom
89
+ for i in range(self.kernel_size):
90
+ c=i
91
+ r=i
92
+ output[i] = (feature_pad[:, :, r : (h + r), c : (w + c)] * feature_normalized).sum(1)
93
+
94
+ # col
95
+ for i in range(self.kernel_size):
96
+ c=self.kernel_size//2
97
+ r=i
98
+ output[1*self.kernel_size+i] = (feature_pad[:, :, r : (h + r), c : (w + c)] * feature_normalized).sum(1)
99
+
100
+ # right-top to left-bottom
101
+ for i in range(self.kernel_size):
102
+ c=(self.kernel_size-1)-i
103
+ r=i
104
+ output[2*self.kernel_size+i] = (feature_pad[:, :, r : (h + r), c : (w + c)] * feature_normalized).sum(1)
105
+
106
+ # row
107
+ for i in range(self.kernel_size):
108
+ c=i
109
+ r=self.kernel_size//2
110
+ output[3*self.kernel_size+i] = (feature_pad[:, :, r : (h + r), c : (w + c)] * feature_normalized).sum(1)
111
+
112
+ output = output.transpose(0, 1).contiguous()
113
+ output = torch.cat((feature, output), 1)
114
+ output = self.conv(output)
115
+ # output = F.relu(output)
116
+
117
+ return output
118
+
119
+ def c2_msra_fill(module: nn.Module) -> None:
120
+ """
121
+ Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
122
+ Also initializes `module.bias` to 0.
123
+
124
+ Args:
125
+ module (torch.nn.Module): module to initialize.
126
+ """
127
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
128
+ if module.bias is not None:
129
+ nn.init.constant_(module.bias, 0)
130
+
131
+ class Conv2d(nn.Conv2d):
132
+ """
133
+ A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
134
+ """
135
+
136
+ def __init__(self, *args, **kwargs):
137
+ """
138
+ Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
139
+ Args:
140
+ norm (nn.Module, optional): a normalization layer
141
+ activation (callable(Tensor) -> Tensor): a callable activation function
142
+ It assumes that norm layer is used before activation.
143
+ """
144
+ norm = kwargs.pop("norm", None)
145
+ activation = kwargs.pop("activation", None)
146
+ super().__init__(*args, **kwargs)
147
+
148
+ self.norm = norm
149
+ self.activation = activation
150
+
151
+ def forward(self, x):
152
+ x = F.conv2d(
153
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
154
+ )
155
+ if self.norm is not None:
156
+ x = self.norm(x)
157
+ if self.activation is not None:
158
+ x = self.activation(x)
159
+ return x
160
+
161
+ class CNNBlockBase(nn.Module):
162
+ """
163
+ A CNN block is assumed to have input channels, output channels and a stride.
164
+ The input and output of `forward()` method must be NCHW tensors.
165
+ The method can perform arbitrary computation but must match the given
166
+ channels and stride specification.
167
+ Attribute:
168
+ in_channels (int):
169
+ out_channels (int):
170
+ stride (int):
171
+ """
172
+
173
+ def __init__(self, in_channels, out_channels, stride):
174
+ """
175
+ The `__init__` method of any subclass should also contain these arguments.
176
+ Args:
177
+ in_channels (int):
178
+ out_channels (int):
179
+ stride (int):
180
+ """
181
+ super().__init__()
182
+ self.in_channels = in_channels
183
+ self.out_channels = out_channels
184
+ self.stride = stride
185
+
186
+ class BottleneckBlock(CNNBlockBase):
187
+ """
188
+ The standard bottleneck residual block used by ResNet-50, 101 and 152
189
+ defined in :paper:`ResNet`. It contains 3 conv layers with kernels
190
+ 1x1, 3x3, 1x1, and a projection shortcut if needed.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ in_channels,
196
+ out_channels,
197
+ *,
198
+ bottleneck_channels,
199
+ stride=1,
200
+ num_groups=1,
201
+ norm="GN",
202
+ stride_in_1x1=False,
203
+ dilation=1,
204
+ num_norm_groups=32,
205
+ kernel_size = (1,3,1)
206
+ ):
207
+ """
208
+ Args:
209
+ bottleneck_channels (int): number of output channels for the 3x3
210
+ "bottleneck" conv layers.
211
+ num_groups (int): number of groups for the 3x3 conv layer.
212
+ norm (str or callable): normalization for all conv layers.
213
+ See :func:`layers.get_norm` for supported format.
214
+ stride_in_1x1 (bool): when stride>1, whether to put stride in the
215
+ first 1x1 convolution or the bottleneck 3x3 convolution.
216
+ dilation (int): the dilation rate of the 3x3 conv layer.
217
+ """
218
+ super().__init__(in_channels, out_channels, stride)
219
+
220
+ if in_channels != out_channels:
221
+ self.shortcut = Conv2d(
222
+ in_channels,
223
+ out_channels,
224
+ kernel_size=1,
225
+ stride=stride,
226
+ bias=False,
227
+ norm=get_norm(norm, out_channels, num_norm_groups),
228
+ )
229
+ else:
230
+ self.shortcut = None
231
+
232
+ # The original MSRA ResNet models have stride in the first 1x1 conv
233
+ # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
234
+ # stride in the 3x3 conv
235
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
236
+
237
+ self.conv1 = Conv2d(
238
+ in_channels,
239
+ bottleneck_channels,
240
+ kernel_size=kernel_size[0],
241
+ stride=stride_1x1,
242
+ padding=(kernel_size[0]-1)//2,
243
+ bias=False,
244
+ norm=get_norm(norm, bottleneck_channels, num_norm_groups),
245
+ )
246
+
247
+ self.conv2 = Conv2d(
248
+ bottleneck_channels,
249
+ bottleneck_channels,
250
+ kernel_size=kernel_size[1],
251
+ stride=stride_3x3,
252
+ padding=dilation*(kernel_size[1]-1)//2,
253
+ bias=False,
254
+ groups=num_groups,
255
+ dilation=dilation,
256
+ norm=get_norm(norm, bottleneck_channels, num_norm_groups),
257
+ )
258
+
259
+ self.conv3 = Conv2d(
260
+ bottleneck_channels,
261
+ out_channels,
262
+ kernel_size=kernel_size[2],
263
+ bias=False,
264
+ norm=get_norm(norm, out_channels, num_norm_groups),
265
+ )
266
+
267
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
268
+ if layer is not None: # shortcut can be None
269
+ c2_msra_fill(layer)
270
+
271
+ # Zero-initialize the last normalization in each residual branch,
272
+ # so that at the beginning, the residual branch starts with zeros,
273
+ # and each residual block behaves like an identity.
274
+ # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
275
+ # "For BN layers, the learnable scaling coefficient γ is initialized
276
+ # to be 1, except for each residual block's last BN
277
+ # where γ is initialized to be 0."
278
+
279
+ # nn.init.constant_(self.conv3.norm.weight, 0)
280
+ # TODO this somehow hurts performance when training GN models from scratch.
281
+ # Add it as an option when we need to use this code to train a backbone.
282
+
283
+ def forward(self, x):
284
+ out = self.conv1(x)
285
+ out = F.relu_(out)
286
+
287
+ out = self.conv2(out)
288
+ out = F.relu_(out)
289
+
290
+ out = self.conv3(out)
291
+
292
+ if self.shortcut is not None:
293
+ shortcut = self.shortcut(x)
294
+ else:
295
+ shortcut = x
296
+
297
+ out += shortcut
298
+ out = F.relu_(out)
299
+ return out
300
+
301
+ class ResNet(nn.Module):
302
+ """
303
+ Implement :paper:`ResNet`.
304
+ """
305
+
306
+ def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
307
+ """
308
+ Args:
309
+ stem (nn.Module): a stem module
310
+ stages (list[list[CNNBlockBase]]): several (typically 4) stages,
311
+ each contains multiple :class:`CNNBlockBase`.
312
+ num_classes (None or int): if None, will not perform classification.
313
+ Otherwise, will create a linear layer.
314
+ out_features (list[str]): name of the layers whose outputs should
315
+ be returned in forward. Can be anything in "stem", "linear", or "res2" ...
316
+ If None, will return the output of the last layer.
317
+ freeze_at (int): The number of stages at the beginning to freeze.
318
+ see :meth:`freeze` for detailed explanation.
319
+ """
320
+ super().__init__()
321
+ self.stem = stem
322
+ self.num_classes = num_classes
323
+
324
+ current_stride = self.stem.stride
325
+ self._out_feature_strides = {"stem": current_stride}
326
+ self._out_feature_channels = {"stem": self.stem.out_channels}
327
+
328
+ self.stage_names, self.stages = [], []
329
+
330
+ if out_features is not None:
331
+ # Avoid keeping unused layers in this module. They consume extra memory
332
+ # and may cause allreduce to fail
333
+ num_stages = max(
334
+ [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
335
+ )
336
+ stages = stages[:num_stages]
337
+ for i, blocks in enumerate(stages):
338
+ assert len(blocks) > 0, len(blocks)
339
+ for block in blocks:
340
+ assert isinstance(block, CNNBlockBase), block
341
+
342
+ name = "res" + str(i + 2)
343
+ stage = nn.Sequential(*blocks)
344
+
345
+ self.add_module(name, stage)
346
+ self.stage_names.append(name)
347
+ self.stages.append(stage)
348
+
349
+ self._out_feature_strides[name] = current_stride = int(
350
+ current_stride * np.prod([k.stride for k in blocks])
351
+ )
352
+ self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
353
+ self.stage_names = tuple(self.stage_names) # Make it static for scripting
354
+
355
+ if num_classes is not None:
356
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
357
+ self.linear = nn.Linear(curr_channels, num_classes)
358
+
359
+ # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
360
+ # "The 1000-way fully-connected layer is initialized by
361
+ # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
362
+ nn.init.normal_(self.linear.weight, std=0.01)
363
+ name = "linear"
364
+
365
+ if out_features is None:
366
+ out_features = [name]
367
+ self._out_features = out_features
368
+ assert len(self._out_features)
369
+ children = [x[0] for x in self.named_children()]
370
+ for out_feature in self._out_features:
371
+ assert out_feature in children, "Available children: {}".format(", ".join(children))
372
+ self.freeze(freeze_at)
373
+
374
+ def forward(self, x):
375
+ """
376
+ Args:
377
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
378
+ Returns:
379
+ dict[str->Tensor]: names and the corresponding features
380
+ """
381
+ assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
382
+ outputs = {}
383
+ x = self.stem(x)
384
+ if "stem" in self._out_features:
385
+ outputs["stem"] = x
386
+ for name, stage in zip(self.stage_names, self.stages):
387
+ x = stage(x)
388
+ if name in self._out_features:
389
+ outputs[name] = x
390
+ if self.num_classes is not None:
391
+ x = self.avgpool(x)
392
+ x = torch.flatten(x, 1)
393
+ x = self.linear(x)
394
+ if "linear" in self._out_features:
395
+ outputs["linear"] = x
396
+ return outputs
397
+
398
+ def freeze(self, freeze_at=0):
399
+ """
400
+ Freeze the first several stages of the ResNet. Commonly used in
401
+ fine-tuning.
402
+ Layers that produce the same feature map spatial size are defined as one
403
+ "stage" by :paper:`FPN`.
404
+ Args:
405
+ freeze_at (int): number of stages to freeze.
406
+ `1` means freezing the stem. `2` means freezing the stem and
407
+ one residual stage, etc.
408
+ Returns:
409
+ nn.Module: this ResNet itself
410
+ """
411
+ if freeze_at >= 1:
412
+ self.stem.freeze()
413
+ for idx, stage in enumerate(self.stages, start=2):
414
+ if freeze_at >= idx:
415
+ for block in stage.children():
416
+ block.freeze()
417
+ return self
418
+
419
+ @staticmethod
420
+ def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
421
+ """
422
+ Create a list of blocks of the same type that forms one ResNet stage.
423
+ Args:
424
+ block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
425
+ stage. A module of this type must not change spatial resolution of inputs unless its
426
+ stride != 1.
427
+ num_blocks (int): number of blocks in this stage
428
+ in_channels (int): input channels of the entire stage.
429
+ out_channels (int): output channels of **every block** in the stage.
430
+ kwargs: other arguments passed to the constructor of
431
+ `block_class`. If the argument name is "xx_per_block", the
432
+ argument is a list of values to be passed to each block in the
433
+ stage. Otherwise, the same argument is passed to every block
434
+ in the stage.
435
+ Returns:
436
+ list[CNNBlockBase]: a list of block module.
437
+ Examples:
438
+ ::
439
+ stage = ResNet.make_stage(
440
+ BottleneckBlock, 3, in_channels=16, out_channels=64,
441
+ bottleneck_channels=16, num_groups=1,
442
+ stride_per_block=[2, 1, 1],
443
+ dilations_per_block=[1, 1, 2]
444
+ )
445
+ Usually, layers that produce the same feature map spatial size are defined as one
446
+ "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
447
+ all be 1.
448
+ """
449
+ blocks = []
450
+ for i in range(num_blocks):
451
+ curr_kwargs = {}
452
+ for k, v in kwargs.items():
453
+ if k.endswith("_per_block"):
454
+ assert len(v) == num_blocks, (
455
+ f"Argument '{k}' of make_stage should have the "
456
+ f"same length as num_blocks={num_blocks}."
457
+ )
458
+ newk = k[: -len("_per_block")]
459
+ assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
460
+ curr_kwargs[newk] = v[i]
461
+ else:
462
+ curr_kwargs[k] = v
463
+
464
+ blocks.append(
465
+ block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
466
+ )
467
+ in_channels = out_channels
468
+ return blocks
469
+
470
+ @staticmethod
471
+ def make_default_stages(depth, block_class=None, **kwargs):
472
+ """
473
+ Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152).
474
+ If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
475
+ instead for fine-grained customization.
476
+ Args:
477
+ depth (int): depth of ResNet
478
+ block_class (type): the CNN block class. Has to accept
479
+ `bottleneck_channels` argument for depth > 50.
480
+ By default it is BasicBlock or BottleneckBlock, based on the
481
+ depth.
482
+ kwargs:
483
+ other arguments to pass to `make_stage`. Should not contain
484
+ stride and channels, as they are predefined for each depth.
485
+ Returns:
486
+ list[list[CNNBlockBase]]: modules in all stages; see arguments of
487
+ :class:`ResNet.__init__`.
488
+ """
489
+ num_blocks_per_stage = {
490
+ 18: [2, 2, 2, 2],
491
+ 34: [3, 4, 6, 3],
492
+ 50: [3, 4, 6, 3],
493
+ 101: [3, 4, 23, 3],
494
+ 152: [3, 8, 36, 3],
495
+ }[depth]
496
+ if block_class is None:
497
+ block_class = BasicBlock if depth < 50 else BottleneckBlock
498
+ if depth < 50:
499
+ in_channels = [64, 64, 128, 256]
500
+ out_channels = [64, 128, 256, 512]
501
+ else:
502
+ in_channels = [64, 256, 512, 1024]
503
+ out_channels = [256, 512, 1024, 2048]
504
+ ret = []
505
+ for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
506
+ if depth >= 50:
507
+ kwargs["bottleneck_channels"] = o // 4
508
+ ret.append(
509
+ ResNet.make_stage(
510
+ block_class=block_class,
511
+ num_blocks=n,
512
+ stride_per_block=[s] + [1] * (n - 1),
513
+ in_channels=i,
514
+ out_channels=o,
515
+ **kwargs,
516
+ )
517
+ )
518
+ return ret
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ numpy
2
+ torch
3
+ torchvision
4
+ pillow
5
+ gradio
6
+ matplotlib
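
To run the demo outside of the Space, the usual workflow would be to clone the repository with Git LFS enabled (so that ckpts/dino_spair_0300.pth is actually downloaded rather than left as a pointer file), install the packages listed above with pip, and run app.py; exact commands depend on the local setup.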