# Copyright (C) 2022-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # -------------------------------------------------------- # CroCo model during pretraining # -------------------------------------------------------- import torch import torch.nn as nn torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 from functools import partial from models.blocks import Block, DecoderBlock, PatchEmbed from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D from models.masking import RandomMask from transformers import PretrainedConfig from transformers import PreTrainedModel class CrocoConfig(PretrainedConfig): model_type = "croco" def __init__( self, img_size=224, # input image size patch_size=16, # patch_size mask_ratio=0.9, # ratios of masked tokens enc_embed_dim=768, # encoder feature dimension enc_depth=12, # encoder depth enc_num_heads=12, # encoder number of heads in the transformer block dec_embed_dim=512, # decoder feature dimension dec_depth=8, # decoder depth dec_num_heads=16, # decoder number of heads in the transformer block mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder pos_embed="cosine", # positional embedding (either cosine or RoPE100) ): super().__init__() self.img_size = img_size self.patch_size = patch_size self.mask_ratio = mask_ratio self.enc_embed_dim = enc_embed_dim self.enc_depth = enc_depth self.enc_num_heads = enc_num_heads self.dec_embed_dim = dec_embed_dim self.dec_depth = dec_depth self.dec_num_heads = dec_num_heads self.mlp_ratio = mlp_ratio self.norm_layer = norm_layer self.norm_im2_in_dec = norm_im2_in_dec self.pos_embed = pos_embed class CroCoNet(PreTrainedModel): config_class = CrocoConfig base_model_prefix = "croco" def __init__(self, config: CrocoConfig): super().__init__(config) # patch embeddings (with initialization done as in MAE) self._set_patch_embed(config.img_size, config.patch_size, config.enc_embed_dim) # mask generations self._set_mask_generator(self.patch_embed.num_patches, config.mask_ratio) self.pos_embed = config.pos_embed if config.pos_embed == "cosine": # positional embedding of the encoder enc_pos_embed = get_2d_sincos_pos_embed( config.enc_embed_dim, int(self.patch_embed.num_patches**0.5), n_cls_token=0, ) self.register_buffer( "enc_pos_embed", torch.from_numpy(enc_pos_embed).float() ) # positional embedding of the decoder dec_pos_embed = get_2d_sincos_pos_embed( config.dec_embed_dim, int(self.patch_embed.num_patches**0.5), n_cls_token=0, ) self.register_buffer( "dec_pos_embed", torch.from_numpy(dec_pos_embed).float() ) # pos embedding in each block self.rope = None # nothing for cosine elif config.pos_embed.startswith("RoPE"): # eg RoPE100 self.enc_pos_embed = None # nothing to add in the encoder with RoPE self.dec_pos_embed = None # nothing to add in the decoder with RoPE if RoPE2D is None: raise ImportError( "Cannot find cuRoPE2D, please install it following the README instructions" ) freq = float(config.pos_embed[len("RoPE") :]) self.rope = RoPE2D(freq=freq) else: raise NotImplementedError("Unknown pos_embed " + config.pos_embed) # transformer for the encoder self.enc_depth = config.enc_depth self.enc_embed_dim = config.enc_embed_dim self.enc_blocks = nn.ModuleList( [ Block( config.enc_embed_dim, config.enc_num_heads, config.mlp_ratio, qkv_bias=True, norm_layer=config.norm_layer, rope=self.rope, ) for i in range(config.enc_depth) ] ) self.enc_norm = config.norm_layer(config.enc_embed_dim) # masked tokens # self._set_mask_token(config.dec_embed_dim) self.mask_token = None # decoder self._set_decoder( config.enc_embed_dim, config.dec_embed_dim, config.dec_num_heads, config.dec_depth, config.mlp_ratio, config.norm_layer, config.norm_im2_in_dec, ) # prediction head self._set_prediction_head(config.dec_embed_dim, config.patch_size) # initializer weights self.initialize_weights() def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim) def _set_mask_generator(self, num_patches, mask_ratio): self.mask_generator = RandomMask(num_patches, mask_ratio) def _set_mask_token(self, dec_embed_dim): self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim)) def _set_decoder( self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec, ): self.dec_depth = dec_depth self.dec_embed_dim = dec_embed_dim # transfer from encoder to decoder self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True) # transformer for the decoder self.dec_blocks = nn.ModuleList( [ DecoderBlock( dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope, ) for i in range(dec_depth) ] ) # final norm layer self.dec_norm = norm_layer(dec_embed_dim) def _set_prediction_head(self, dec_embed_dim, patch_size): self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True) def initialize_weights(self): # patch embed self.patch_embed._init_weights() # mask tokens if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=0.02) # linears and layer norms self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): # we use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def _encode_image(self, image, do_mask=False, return_all_blocks=False): """ image has B x 3 x img_size x img_size do_mask: whether to perform masking or not return_all_blocks: if True, return the features at the end of every block instead of just the features from the last block (eg for some prediction heads) """ # embed the image into patches (x has size B x Npatches x C) # and get position if each return patch (pos has size B x Npatches x 2) x, pos = self.patch_embed(image) # add positional embedding without cls token if self.enc_pos_embed is not None: x = x + self.enc_pos_embed[None, ...] # apply masking B, N, C = x.size() if do_mask: masks = self.mask_generator(x) x = x[~masks].view(B, -1, C) posvis = pos[~masks].view(B, -1, 2) else: B, N, C = x.size() masks = torch.zeros((B, N), dtype=bool) posvis = pos # now apply the transformer encoder and normalization if return_all_blocks: out = [] for blk in self.enc_blocks: x = blk(x, posvis) out.append(x) out[-1] = self.enc_norm(out[-1]) return out, pos, masks else: for blk in self.enc_blocks: x = blk(x, posvis) x = self.enc_norm(x) return x, pos, masks def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False): """ return_all_blocks: if True, return the features at the end of every block instead of just the features from the last block (eg for some prediction heads) masks1 can be None => assume image1 fully visible """ # encoder to decoder layer visf1 = self.decoder_embed(feat1) f2 = self.decoder_embed(feat2) # append masked tokens to the sequence B, Nenc, C = visf1.size() if masks1 is None: # downstreams f1_ = visf1 else: # pretraining Ntotal = masks1.size(1) f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype) f1_[~masks1] = visf1.view(B * Nenc, C) # add positional embedding if self.dec_pos_embed is not None: f1_ = f1_ + self.dec_pos_embed f2 = f2 + self.dec_pos_embed # apply Transformer blocks out = f1_ out2 = f2 if return_all_blocks: _out, out = out, [] for blk in self.dec_blocks: _out, out2 = blk(_out, out2, pos1, pos2) out.append(_out) out[-1] = self.dec_norm(out[-1]) else: for blk in self.dec_blocks: out, out2 = blk(out, out2, pos1, pos2) out = self.dec_norm(out) return out def patchify(self, imgs): """ imgs: (B, 3, H, W) x: (B, L, patch_size**2 *3) """ p = self.patch_embed.patch_size[0] assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 h = w = imgs.shape[2] // p x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) x = torch.einsum("nchpwq->nhwpqc", x) x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) return x def unpatchify(self, x, channels=3): """ x: (N, L, patch_size**2 *channels) imgs: (N, 3, H, W) """ patch_size = self.patch_embed.patch_size[0] h = w = int(x.shape[1] ** 0.5) assert h * w == x.shape[1] x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels)) x = torch.einsum("nhwpqc->nchpwq", x) imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size)) return imgs # def forward(self, img1, img2): # """ # img1: tensor of size B x 3 x img_size x img_size # img2: tensor of size B x 3 x img_size x img_size # out will be B x N x (3*patch_size*patch_size) # masks are also returned as B x N just in case # """ # # encoder of the masked first image # feat1, pos1, mask1 = self._encode_image(img1, do_mask=True) # # encoder of the second image # feat2, pos2, _ = self._encode_image(img2, do_mask=False) # # decoder # decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2) # # prediction head # out = self.prediction_head(decfeat) # # get target # target = self.patchify(img1) # return out, mask1, target