from typing import Dict, List
from math import sqrt, log
import torch
import torch.nn as nn
import torch.nn.functional as F

from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder
from utils import get_model


class MaskFormer(nn.Module):
    def __init__(
            self,
            n_queries: int = 100,
            arch: str = "vit_small",
            patch_size: int = 8,
            training_method: str = "dino",
            n_decoder_layers: int = 6,
            normalize_before: bool = False,
            return_intermediate: bool = False,
            learnable_pixel_decoder: bool = False,
            lateral_connection: bool = False,
            scale_factor: int = 2,
            abs_2d_pe_init: bool = False,
            use_binary_classifier: bool = False
    ):
        """Define a encoder and decoder along with queries to be learned through the decoder."""
        super(MaskFormer, self).__init__()

        if arch == "vit_small":
            self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method)
            n_dims: int = self.encoder.n_embs
            n_heads: int = self.encoder.n_heads
            mlp_ratio: int = self.encoder.mlp_ratio
        else:
            self.encoder = get_model(arch=arch, training_method=training_method)
            n_dims_resnet: int = self.encoder.n_embs
            n_dims: int = 384
            n_heads: int = 6
            mlp_ratio: int = 4
            self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1)

        decoder_layer = TransformerDecoderLayer(
            n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before
        )
        self.decoder = TransformerDecoder(
            decoder_layer,
            n_decoder_layers,
            norm=nn.LayerNorm(n_dims),
            return_intermediate=return_intermediate
        )

        self.query_embed = nn.Embedding(n_queries, n_dims).weight  # initialized with gaussian(0, 1)

        if use_binary_classifier:
            # a 3-layer MLP head maps each query embedding to a single objectness logit
            self.ffn = MLP(n_dims, n_dims, 1, num_layers=3)
        else:
            # project queries with an MLP and classify them into object / no-object
            self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
            self.linear_classifier = nn.Linear(n_dims, 2)
            self.norm = nn.LayerNorm(n_dims)

        self.arch = arch
        self.use_binary_classifier = use_binary_classifier
        self.lateral_connection = lateral_connection
        self.learnable_pixel_decoder = learnable_pixel_decoder
        self.scale_factor = scale_factor

    # copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
    @staticmethod
    def positional_encoding_2d(n_dims: int, height: int, width: int):
        """
        :param n_dims: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: n_dims x height x width position matrix
        """
        if n_dims % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "a dimension not divisible by 4 (got dim={:d})".format(n_dims))
        pe = torch.zeros(n_dims, height, width)
        # Each dimension use half of d_model
        d_model = int(n_dims / 2)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe

    def forward_encoder(self, x: torch.Tensor):
        """
        :param x: b x c x h x w
        :return patch_tokens: b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet)
        """
        if self.arch == "vit_small":
            encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x)
            all_patch_tokens: List[torch.Tensor] = list()
            for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]:
                patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :]  # b x hw x n_dims
                all_patch_tokens.append(patch_tokens)

            all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0)  # depth x b x hw x n_dims
            all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2)  # b x depth x n_dims x hw
            return all_patch_tokens
        else:
            encoder_outputs = self.linear_layer(self.encoder(x)[-1])  # b x n_dims x h x w
            return encoder_outputs

    def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor:
        """Forward transformer decoder given patch tokens from the encoder's last layer.
        :param patch_tokens: b x n_dims x hw -> hw x b x n_dims
        :param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication
        between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting
        experiment.
        :return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims
        """
        b = patch_tokens.shape[0]
        patch_tokens = patch_tokens.permute(2, 0, 1)  # b x n_dims x hw -> hw x b x n_dims

        # n_queries x n_dims -> n_queries x b x n_dims
        queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1)

        if skip_decoder:
            # bypass the decoder and hand the raw learnable queries to forward(), which
            # multiplies them with the encoder features to produce the mask predictions
            return queries.permute(1, 0, 2)  # n_queries x b x n_dims -> b x n_queries x n_dims

        queries: torch.Tensor = self.decoder.forward(
            tgt=torch.zeros_like(queries),
            memory=patch_tokens,
            query_pos=queries
        ).squeeze(dim=0)

        if len(queries.shape) == 3:
            queries: torch.Tensor = queries.permute(1, 0, 2)  # n_queries x b x n_dims -> b x n_queries x n_dims
        elif len(queries.shape) == 4:
            # n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims
            queries: torch.Tensor = queries.permute(2, 0, 1, 3)
        return queries

    def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None):
        """Upsample patch tokens by self.scale_factor to produce per-pixel features.
        :param patch_tokens: b x n_dims x hw -> b x n_dims x h x w
        :param input_size: (h, w) of the patch grid; if None, a square grid is assumed
        :return features: b x n_dims x (h * scale_factor) x (w * scale_factor)
        """

        if input_size is None:
            # assume square shape features
            hw = patch_tokens.shape[-1]
            h = w = int(sqrt(hw))
        else:
            # arbitrary shape features
            h, w = input_size
        patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w)

        assert len(patch_tokens.shape) == 4
        patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear")
        return patch_tokens

    def forward(self, x, encoder_only=False, skip_decoder: bool = False):
        """
        :param x: b x c x h x w
        :param encoder_only: if True, return only the encoder's patch tokens
        :param skip_decoder: if True, bypass the transformer decoder (see forward_transformer_decoder)
        :return dict_outputs: mask predictions (and per-query objectness scores if a binary classifier is used)
        """
        dict_outputs: dict = dict()

        # b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50)
        features: torch.Tensor = self.forward_encoder(x)

        if self.arch == "vit_small":
            # extract the last layer for decoder input
            last_layer_features: torch.Tensor = features[:, -1, ...]  # b x n_dims x hw
        else:
            # transform the shape of the features to the one compatible with transformer decoder
            b, n_dims, h, w = features.shape
            last_layer_features: torch.Tensor = features.view(b, n_dims, h * w)  # b x n_dims x hw

        if encoder_only:
            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size

            b, n_dims, hw = last_layer_features.shape
            # b x n_dims x hw -> b x n_dims x h x w -> b x h x w x n_dims
            # (a plain .view(b, _h, _w, n_dims) would scramble the channel and spatial axes)
            patch_tokens = last_layer_features.view(b, n_dims, _h, _w).permute(0, 2, 3, 1)
            dict_outputs.update({"patch_tokens": patch_tokens})
            return dict_outputs

        # transformer decoder forward
        queries: torch.Tensor = self.forward_transformer_decoder(
            last_layer_features,
            skip_decoder=skip_decoder
        )  # b x n_queries x n_dims or b x n_layers x n_queries x n_dims

        # pixel decoder forward (upsampling the patch tokens by self.scale_factor)
        if self.arch == "vit_small":
            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
        else:
            _h, _w = h, w
        features: torch.Tensor = self.forward_pixel_decoder(
            patch_tokens=features if self.lateral_connection else last_layer_features,
            input_size=(_h, _w)
        )  # b x n_dims x h x w

        # queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims
        # features: b x n_dims x h x w
        # mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w
        if len(queries.shape) == 3:
            mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features)
        elif self.use_binary_classifier:
            mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
        else:
            mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features))

        if self.use_binary_classifier:
            # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims
            queries = queries.permute(1, 0, 2, 3)
            objectness: List[torch.Tensor] = list()
            for queries_per_layer in queries:  # queries_per_layer: b x n_queries x n_dims
                objectness_per_layer = self.ffn(queries_per_layer)  # b x n_queries x 1
                objectness.append(objectness_per_layer)
            # n_layers x b x n_queries x 1 -> b x n_layers x n_queries x 1
            objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3)
            dict_outputs.update({"objectness": torch.sigmoid(objectness)})

        dict_outputs.update({"mask_pred": mask_pred})
        return dict_outputs


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2):
        super(UpsampleBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
            nn.GroupNorm(n_groups, out_channels),
            nn.ReLU()
        )
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear")
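

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming `utils.get_model` can build the
    # DINO-pretrained ViT-S/8 backbone (pretrained weights must be obtainable).
    model = MaskFormer(
        n_queries=100,
        arch="vit_small",
        patch_size=8,
        return_intermediate=True,
        use_binary_classifier=True
    )
    model.eval()
    x = torch.rand(2, 3, 224, 224)
    with torch.no_grad():
        outputs = model(x)
    # expected shapes: objectness 2 x 6 x 100 x 1, mask_pred 2 x 6 x 100 x 56 x 56
    for name, tensor in outputs.items():
        print(name, tuple(tensor.shape))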