# AnySplat/src/model/encoder/backbone/backbone_croco_multiview.py

from copy import deepcopy
from dataclasses import dataclass
from typing import Literal
import torch
from einops import rearrange
from torch import nn
from .croco.blocks import DecoderBlock
from .croco.croco import CroCoNet
from .croco.misc import (fill_default_args, freeze_all_params, transpose_to_landscape,
                         is_symmetrized, interleave, make_batch_symmetric)
from .croco.patch_embed import get_patch_embed
from .backbone import Backbone
from src.geometry.camera_emb import get_intrinsic_embedding

inf = float('inf')

croco_params = {
'ViTLarge_BaseDecoder': {
'enc_depth': 24,
'dec_depth': 12,
'enc_embed_dim': 1024,
'dec_embed_dim': 768,
'enc_num_heads': 16,
'dec_num_heads': 12,
'pos_embed': 'RoPE100',
'img_size': (512, 512),
},
}

default_dust3r_params = {
'enc_depth': 24,
'dec_depth': 12,
'enc_embed_dim': 1024,
'dec_embed_dim': 768,
'enc_num_heads': 16,
'dec_num_heads': 12,
'pos_embed': 'RoPE100',
'patch_embed_cls': 'PatchEmbedDust3R',
'img_size': (512, 512),
'head_type': 'dpt',
'output_mode': 'pts3d',
'depth_mode': ('exp', -inf, inf),
'conf_mode': ('exp', 1, inf)
}
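
# NOTE: `default_dust3r_params` mirrors the construction arguments used by DUSt3R;
# it does not appear to be consumed anywhere in this module and is presumably kept
# for reference.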


@dataclass
class BackboneCrocoCfg:
name: Literal["croco"]
    model: Literal["ViTLarge_BaseDecoder", "ViTBase_SmallDecoder", "ViTBase_BaseDecoder"]  # the last two are kept for interface compatibility but are not supported
patch_embed_cls: str = 'PatchEmbedDust3R' # PatchEmbedDust3R or ManyAR_PatchEmbed
asymmetry_decoder: bool = True
intrinsics_embed_loc: Literal["encoder", "decoder", "none"] = 'none'
intrinsics_embed_degree: int = 0
    intrinsics_embed_type: Literal["pixelwise", "linear", "token"] = 'token'  # how the intrinsics embedding is injected


class AsymmetricCroCoMulti(CroCoNet):
    """ Multi-view extension of the asymmetric CroCo design: a shared (siamese) ViT
    encoder processes every view, then a first decoder handles the reference view
    (view 0) while a second decoder is shared by all remaining views, so downstream
    heads can predict 3D points for every view in view 0's frame (hence the asymmetry).
    """

    def __init__(self, cfg: BackboneCrocoCfg, d_in: int) -> None:
self.intrinsics_embed_loc = cfg.intrinsics_embed_loc
self.intrinsics_embed_degree = cfg.intrinsics_embed_degree
self.intrinsics_embed_type = cfg.intrinsics_embed_type
self.intrinsics_embed_encoder_dim = 0
self.intrinsics_embed_decoder_dim = 0
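        # 'pixelwise' embeddings are concatenated to the inputs as extra channels;
        # (degree + 1) ** 2 channels are assumed to correspond to a spherical-harmonics
        # style expansion of the camera rays, falling back to 3 raw channels at degree 0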
if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
self.intrinsics_embed_encoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3
elif self.intrinsics_embed_loc == 'decoder' and self.intrinsics_embed_type == 'pixelwise':
self.intrinsics_embed_decoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3
self.patch_embed_cls = cfg.patch_embed_cls
self.croco_args = fill_default_args(croco_params[cfg.model], CroCoNet.__init__)
super().__init__(**croco_params[cfg.model])
if cfg.asymmetry_decoder:
self.dec_blocks2 = deepcopy(self.dec_blocks) # This is used in DUSt3R and MASt3R
        if self.intrinsics_embed_type in ('linear', 'token'):
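            # maps the flattened 3x3 intrinsics matrix (9 values) to the encoder
            # width of the ViT-Large configuration (enc_embed_dim = 1024)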
self.intrinsic_encoder = nn.Linear(9, 1024)
# self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
# self.set_freeze(freeze)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768, in_chans=3):
in_chans = in_chans + self.intrinsics_embed_encoder_dim
self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans)

    def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
self.dec_depth = dec_depth
self.dec_embed_dim = dec_embed_dim
        # project encoder features (plus optional pixelwise intrinsics channels) to the decoder width
enc_embed_dim = enc_embed_dim + self.intrinsics_embed_decoder_dim
self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
# transformer for the decoder
self.dec_blocks = nn.ModuleList([
DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
for i in range(dec_depth)])
# final norm layer
self.dec_norm = norm_layer(dec_embed_dim)

    def load_state_dict(self, ckpt, **kw):
# duplicate all weights for the second decoder if not present
new_ckpt = dict(ckpt)
if not any(k.startswith('dec_blocks2') for k in ckpt):
for key, value in ckpt.items():
if key.startswith('dec_blocks'):
new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        assert freeze in ['none', 'mask', 'encoder', 'encoder_decoder'], f"unexpected freeze={freeze}"
to_be_frozen = {
'none': [],
'mask': [self.mask_token],
'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
'encoder_decoder': [self.mask_token, self.patch_embed, self.enc_blocks, self.enc_norm, self.decoder_embed, self.dec_blocks, self.dec_blocks2, self.dec_norm],
}
freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
""" No prediction head """
return

    def _encode_image(self, image, true_shape, intrinsics_embed=None):
# embed the image into patches (x has size B x Npatches x C)
x, pos = self.patch_embed(image, true_shape=true_shape)
if intrinsics_embed is not None:
if self.intrinsics_embed_type == 'linear':
x = x + intrinsics_embed
elif self.intrinsics_embed_type == 'token':
x = torch.cat((x, intrinsics_embed), dim=1)
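                # give the extra intrinsics token its own RoPE position: reuse the
                # first patch position shifted one step past the last patch along the
                # first coordinate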
add_pose = pos[:, 0:1, :].clone()
add_pose[:, :, 0] += (pos[:, -1, 0].unsqueeze(-1) + 1)
pos = torch.cat((pos, add_pose), dim=1)
        # no absolute positional embedding is added here; RoPE is applied inside the
        # attention blocks instead
        assert self.enc_pos_embed is None
# now apply the transformer encoder and normalization
for blk in self.enc_blocks:
x = blk(x, pos)
x = self.enc_norm(x)
return x, pos, None

    def _decoder(self, feat, pose, extra_embed=None):
b, v, l, c = feat.shape
final_output = [feat] # before projection
if extra_embed is not None:
feat = torch.cat((feat, extra_embed), dim=-1)
# project to decoder dim
f = rearrange(feat, "b v l c -> (b v) l c")
f = self.decoder_embed(f)
f = rearrange(f, "(b v) l c -> b v l c", b=b, v=v)
final_output.append(f)

        def generate_ctx_views(x):
b, v, l, c = x.shape
ctx_views = x.unsqueeze(1).expand(b, v, v, l, c)
mask = torch.arange(v).unsqueeze(0) != torch.arange(v).unsqueeze(1)
ctx_views = ctx_views[:, mask].reshape(b, v, v - 1, l, c) # B, V, V-1, L, C
ctx_views = ctx_views.flatten(2, 3) # B, V, (V-1)*L, C
return ctx_views.contiguous()
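
        # e.g. with v = 3 views, the context for view 0 is the concatenated tokens of
        # views [1, 2], for view 1 those of [0, 2], and for view 2 those of [0, 1]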
pos_ctx = generate_ctx_views(pose)
for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
feat_current = final_output[-1]
feat_current_ctx = generate_ctx_views(feat_current)
            # reference view (view 0) goes through the first decoder (dec_blocks)
f1, _ = blk1(feat_current[:, 0].contiguous(), feat_current_ctx[:, 0].contiguous(), pose[:, 0].contiguous(), pos_ctx[:, 0].contiguous())
f1 = f1.unsqueeze(1)
            # all remaining views share the second decoder (dec_blocks2)
f2, _ = blk2(rearrange(feat_current[:, 1:], "b v l c -> (b v) l c"),
rearrange(feat_current_ctx[:, 1:], "b v l c -> (b v) l c"),
rearrange(pose[:, 1:], "b v l c -> (b v) l c"),
rearrange(pos_ctx[:, 1:], "b v l c -> (b v) l c"))
f2 = rearrange(f2, "(b v) l c -> b v l c", b=b, v=v-1)
# store the result
final_output.append(torch.cat((f1, f2), dim=1))
# normalize last output
        del final_output[1]  # the projected encoder features carry no extra information over final_output[0]
last_feat = rearrange(final_output[-1], "b v l c -> (b v) l c")
last_feat = self.dec_norm(last_feat)
final_output[-1] = rearrange(last_feat, "(b v) l c -> b v l c", b=b, v=v)
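        # final_output now holds 1 + dec_depth tensors: the raw encoder features plus
        # each decoder block's output, with the last one layer-normalized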
return final_output

    def forward(self,
                context: dict,
                symmetrize_batch=False,  # kept for interface compatibility; unused here
                return_views=False,  # kept for interface compatibility; unused here
                ):
b, v, _, h, w = context["image"].shape
images_all = context["image"]
# camera embedding in the encoder
if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
intrinsic_embedding = get_intrinsic_embedding(context, degree=self.intrinsics_embed_degree)
images_all = torch.cat((images_all, intrinsic_embedding), dim=2)
intrinsic_embedding_all = None
        if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type in ('token', 'linear'):
intrinsic_embedding = self.intrinsic_encoder(context["intrinsics"].flatten(2))
intrinsic_embedding_all = rearrange(intrinsic_embedding, "b v c -> (b v) c").unsqueeze(1)
        # step 1: encode the input images
images_all = rearrange(images_all, "b v c h w -> (b v) c h w")
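        # one (H, W) row per flattened (b, v) element; every view shares the same size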
shape_all = torch.tensor(images_all.shape[-2:])[None].repeat(b*v, 1)
feat, pose, _ = self._encode_image(images_all, shape_all, intrinsic_embedding_all)
feat = rearrange(feat, "(b v) l c -> b v l c", b=b, v=v)
pose = rearrange(pose, "(b v) l c -> b v l c", b=b, v=v)
# step 2: decoder
dec_feat = self._decoder(feat, pose)
shape = rearrange(shape_all, "(b v) c -> b v c", b=b, v=v)
images = rearrange(images_all, "(b v) c h w -> b v c h w", b=b, v=v)
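        # if an intrinsics token was appended in _encode_image, strip it again so only
        # patch tokens are returned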
if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'token':
dec_feat = list(dec_feat)
for i in range(len(dec_feat)):
dec_feat[i] = dec_feat[i][:, :, :-1]
return dec_feat, shape, images

    @property
    def patch_size(self) -> int:
        return 16  # fixed by the patch embedding used in _set_patch_embed

    @property
    def d_out(self) -> int:
        return 1024  # enc_embed_dim of the ViTLarge_BaseDecoder configuration
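

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module): builds the
# backbone with a hypothetical config and random inputs to illustrate the
# tensor shapes flowing through forward(). Assumes random initialization is
# acceptable, the CroCo RoPE implementation is available, and the patch
# embedding accepts 256x256 inputs (H and W must be multiples of the 16-pixel
# patch size).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = BackboneCrocoCfg(
        name="croco",
        model="ViTLarge_BaseDecoder",
        intrinsics_embed_loc="encoder",
        intrinsics_embed_type="token",
    )
    model = AsymmetricCroCoMulti(cfg, d_in=3).eval()
    b, v, h, w = 1, 3, 256, 256
    context = {
        "image": torch.randn(b, v, 3, h, w),
        "intrinsics": torch.eye(3).expand(b, v, 3, 3).contiguous(),
    }
    with torch.no_grad():
        dec_feat, shape, images = model(context)
    # dec_feat: list of (b, v, n_patches, c) tensors; shape: (b, v, 2); images: as input
    print(len(dec_feat), dec_feat[-1].shape, shape.shape, images.shape)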