# Source: AnySplat / src / model / encoder / heads / linear_gs_head.py
# (web-page residue from the hosting site: "alexnasa's picture", "Upload 243 files", "2568013 verified")
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# dpt head implementation for DUST3R
# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
# the forward function also takes as input a dictionary img_info with key "height" and "width"
# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
# --------------------------------------------------------
from einops import rearrange
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# import dust3r.utils.path_to_croco
from .dpt_block import DPTOutputAdapter, Interpolate, make_fusion_block
from .head_modules import UnetExtractor, AppearanceTransformer, _init_weights
from .postprocess import postprocess
import torchvision
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Resize ``x`` with ``nn.functional.interpolate`` while avoiding INT_MAX issues.

    When the number of output elements exceeds an internal threshold, the input
    is split along dim 0, every chunk is interpolated separately and the results
    are concatenated back along dim 0.

    Args:
        x: input tensor; the element count is computed from ``x.shape[0]``,
            ``x.shape[1]`` and a 2-element ``size``, i.e. a 4-D
            (batch, channels, height, width) layout is assumed.
        size: target ``(height, width)``. Takes precedence over ``scale_factor``.
        scale_factor: used only when ``size`` is None; the target size is then
            ``int(dim * scale_factor)`` for the last two dims.
        mode: interpolation mode forwarded to ``nn.functional.interpolate``.
        align_corners: forwarded only for the interpolating modes (linear,
            bilinear, bicubic, trilinear); ignored for other modes such as
            "nearest" because PyTorch rejects it there.

    Returns:
        The resized tensor.

    Raises:
        ValueError: if both ``size`` and ``scale_factor`` are None.
    """
    if size is None:
        if scale_factor is None:
            # Previously this surfaced as an opaque "int * NoneType" TypeError.
            raise ValueError("either size or scale_factor should be defined")
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    interp_kwargs = {"size": size, "mode": mode}
    # PyTorch only accepts align_corners with interpolating modes; passing it
    # together with e.g. mode="nearest" raises, so drop it in that case.
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        interp_kwargs["align_corners"] = align_corners

    # Threshold (1.5 * 2**30) kept below the signed 32-bit limit.
    INT_MAX = 1610612736

    # Number of elements in the *output* tensor.
    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
    if output_elements <= INT_MAX:
        return nn.functional.interpolate(x, **interp_kwargs)

    # Too large for a single call: process the batch in pieces.
    num_chunks = (output_elements // INT_MAX) + 1
    resized = [
        nn.functional.interpolate(piece, **interp_kwargs)
        for piece in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(resized, dim=0).contiguous()
# class DPTOutputAdapter_fix(DPTOutputAdapter):
# """
# Adapt croco's DPTOutputAdapter implementation for dust3r:
# remove duplicated weigths, and fix forward for dust3r
# """
#
# def init(self, dim_tokens_enc=768):
# super().init(dim_tokens_enc)
# # these are duplicated weights
# del self.act_1_postprocess
# del self.act_2_postprocess
# del self.act_3_postprocess
# del self.act_4_postprocess
#
# self.scratch.refinenet1 = make_fusion_block(256 * 2, False, 1, expand=True)
# self.scratch.refinenet2 = make_fusion_block(256 * 2, False, 1, expand=True)
# self.scratch.refinenet3 = make_fusion_block(256 * 2, False, 1, expand=True)
# # self.scratch.refinenet4 = make_fusion_block(256 * 2, False, 1)
#
# self.depth_encoder = UnetExtractor(in_channel=3)
# self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
# self.out_conv = nn.Conv2d(256+3+4, 256, kernel_size=3, padding=1)
# self.out_relu = nn.ReLU(inplace=True)
#
# self.input_merger = nn.Sequential(
# # nn.Conv2d(256+3+3+1, 256, kernel_size=3, padding=1),
# nn.Conv2d(256+3+3, 256, kernel_size=3, padding=1),
# nn.ReLU(),
# )
#
# def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None):
# assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
# # H, W = input_info['image_size']
# image_size = self.image_size if image_size is None else image_size
# H, W = image_size
# # Number of patches in height and width
# N_H = H // (self.stride_level * self.P_H)
# N_W = W // (self.stride_level * self.P_W)
#
# # Hook decoder onto 4 layers from specified ViT layers
# layers = [encoder_tokens[hook] for hook in self.hooks]
#
# # Extract only task-relevant tokens and ignore global tokens.
# layers = [self.adapt_tokens(l) for l in layers]
#
# # Reshape tokens to spatial representation
# layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
#
# layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
# # Project layers to chosen feature dim
# layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
#
# # get depth features
# depth_features = self.depth_encoder(depths)
# depth_feature1, depth_feature2, depth_feature3 = depth_features
#
# # Fuse layers using refinement stages
# path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
# path_3 = self.scratch.refinenet3(torch.cat([path_4, depth_feature3], dim=1), torch.cat([layers[2], depth_feature3], dim=1))
# path_2 = self.scratch.refinenet2(torch.cat([path_3, depth_feature2], dim=1), torch.cat([layers[1], depth_feature2], dim=1))
# path_1 = self.scratch.refinenet1(torch.cat([path_2, depth_feature1], dim=1), torch.cat([layers[0], depth_feature1], dim=1))
# # path_3 = self.scratch.refinenet3(path_4, layers[2], depth_feature3)
# # path_2 = self.scratch.refinenet2(path_3, layers[1], depth_feature2)
# # path_1 = self.scratch.refinenet1(path_2, layers[0], depth_feature1)
#
# path_1 = self.feat_up(path_1)
# path_1 = torch.cat([path_1, imgs, depths], dim=1)
# if conf is not None:
# path_1 = torch.cat([path_1, conf], dim=1)
# path_1 = self.input_merger(path_1)
#
# # Output head
# out = self.head(path_1)
#
# return out
class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Variant of croco's DPTOutputAdapter adapted for dust3r.

    Removes the duplicated per-stage postprocess weights and replaces the
    forward pass: the fused feature map is resized to the image resolution
    and returned directly (the output head is not applied), together with
    the coarser intermediate fusion features.
    """

    def init(self, dim_tokens_enc=768):
        """Run the parent ``init`` and then strip the duplicated postprocess modules."""
        super().init(dim_tokens_enc)
        # these are duplicated weights
        for stage in (1, 2, 3, 4):
            delattr(self, f"act_{stage}_postprocess")
        # Kept as an attribute although forward() below resizes with
        # custom_interpolate instead of calling it.
        self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None):
        """
        Fuse the hooked encoder layers into a full-resolution feature map.

        Args:
            encoder_tokens: per-layer token tensors, indexed by ``self.hooks``.
            depths, imgs, conf: accepted for interface compatibility; not used
                by this implementation.
            image_size: ``(H, W)`` of the output; falls back to ``self.image_size``.

        Returns:
            ``(out, [path_4, path_3, path_2])`` where ``out`` is the finest fused
            feature map bilinearly resized to ``(H, W)`` and the list holds the
            coarser fusion outputs.
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'

        if image_size is None:
            image_size = self.image_size
        H, W = image_size

        # Token grid size (number of patches along each axis)
        grid_h = H // (self.stride_level * self.P_H)
        grid_w = W // (self.stride_level * self.P_W)

        # Pick the hooked layers and keep only the task-relevant tokens
        hooked = [self.adapt_tokens(encoder_tokens[hook]) for hook in self.hooks]

        # Tokens -> spatial feature maps
        feats = [rearrange(tok, 'b (nh nw) c -> b c nh nw', nh=grid_h, nw=grid_w) for tok in hooked]

        # Per-stage postprocess, then projection to the shared feature dim
        feats = [self.act_postprocess[stage](fmap) for stage, fmap in enumerate(feats)]
        feats = [self.scratch.layer_rn[stage](fmap) for stage, fmap in enumerate(feats)]

        # Coarse-to-fine fusion; the coarsest output is cropped to the spatial
        # size of the next stage before being merged with it.
        coarsest = self.scratch.refinenet4(feats[3])
        ref_h, ref_w = feats[2].shape[2], feats[2].shape[3]
        path_4 = coarsest[:, :, :ref_h, :ref_w]
        path_3 = self.scratch.refinenet3(path_4, feats[2])
        path_2 = self.scratch.refinenet2(path_3, feats[1])
        path_1 = self.scratch.refinenet1(path_2, feats[0])

        # Resize straight to the image resolution; self.head is intentionally
        # not applied here.
        out = custom_interpolate(path_1, size=(H, W), mode='bilinear', align_corners=True)
        return out, [path_4, path_3, path_2]
class PixelwiseTaskWithDPT(nn.Module):
    """DPT-based pixelwise head for dust3r: wraps a DPTOutputAdapter_fix and an optional postprocess step."""

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
        """
        Args:
            n_cls_token: must be 0 (anything else is not implemented).
            hooks_idx: if given, forwarded to the adapter as ``hooks``.
            dim_tokens: if given, forwarded to the adapter's ``init`` as ``dim_tokens_enc``.
            output_width_ratio, num_channels, **kwargs: forwarded to the adapter constructor.
            postprocess: optional callable ``(out, depth_mode, conf_mode) -> out``.
            depth_mode, conf_mode: stored and handed to ``postprocess``.
        """
        super().__init__()
        self.return_all_layers = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        assert n_cls_token == 0, "Not implemented"

        adapter_kwargs = {
            'output_width_ratio': output_width_ratio,
            'num_channels': num_channels,
            **kwargs,
        }
        if hooks_idx is not None:
            adapter_kwargs['hooks'] = hooks_idx
        self.dpt = DPTOutputAdapter_fix(**adapter_kwargs)

        # Only override the adapter's default token dim when one was supplied.
        if dim_tokens is None:
            self.dpt.init()
        else:
            self.dpt.init(dim_tokens_enc=dim_tokens)

    def forward(self, x, depths, imgs, img_info, conf=None):
        """
        Run the adapter and, when configured, the postprocess callable.

        The first two entries of ``img_info`` are passed to the adapter as
        ``image_size``.

        Returns:
            ``(out, interm_feats)`` as produced by the adapter, with ``out``
            postprocessed if ``self.postprocess`` is set.
        """
        target_size = (img_info[0], img_info[1])
        out, interm_feats = self.dpt(x, depths, imgs, image_size=target_size, conf=conf)
        if not self.postprocess:
            return out, interm_feats
        return self.postprocess(out, self.depth_mode, self.conf_mode), interm_feats
class AttnBasedAppearanceHead(nn.Module):
    """
    Attention head for appearance reconstruction.

    Image pixels concatenated with frozen VGG16 features are patchified and
    linearly projected to one token per patch. These image tokens are
    concatenated (channel-wise) with the hooked backbone tokens, decoded to
    ``C_feat`` features per pixel of each patch, and mapped to ``num_channels``
    outputs per pixel.

    The layer shapes and the forward path are kept consistent with each other:
    the tokenizer consumes ``(3 + 512) * ph * pw`` features per patch (what the
    patchify step actually produces) and the decoder consumes
    ``len(hooks) + 1`` token streams (hooked layers plus the image tokens).
    """

    # Channel counts of the two tensors concatenated before patchifying.
    RGB_CHANNELS = 3
    VGG_CHANNELS = 512

    def __init__(self, num_channels, patch_size, feature_dim, last_dim, hooks_idx, dim_tokens, postprocess, depth_mode, conf_mode, head_type='gs_params'):
        """
        Args:
            num_channels: number of output channels per pixel.
            patch_size: ``(ph, pw)`` patch size in pixels.
            hooks_idx: indices of the backbone outputs to read in ``forward``.
            dim_tokens: token widths of the hooked layers; all must be equal.
            feature_dim, last_dim, postprocess, depth_mode, conf_mode, head_type:
                accepted for interface compatibility; not used by this head.
        """
        super().__init__()
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.hooks = hooks_idx

        assert len(set(dim_tokens)) == 1
        token_dim = dim_tokens[0]
        pixels_per_patch = self.patch_size[0] * self.patch_size[1]

        # One token per patch from every pixel's RGB + VGG channels. The old
        # width (3 * p**2 + 512) only matched the patchified input for 1x1
        # patches, where both formulas coincide.
        self.tokenizer = nn.Linear(
            (self.RGB_CHANNELS + self.VGG_CHANNELS) * pixels_per_patch, token_dim, bias=False
        )

        self.C_feat = 128

        # NOTE(review): `pretrained=True` is the legacy torchvision argument;
        # newer releases prefer `weights=...` - confirm the pinned version.
        self.vgg_feature_extractor = torchvision.models.vgg16(pretrained=True).features
        # Freeze the VGG parameters
        for param in self.vgg_feature_extractor.parameters():
            param.requires_grad = False

        decoded_dim = self.C_feat * pixels_per_patch
        self.token_decoder = nn.Sequential(
            # +1: the image tokens are concatenated with the hooked layers.
            nn.Linear(token_dim * (len(self.hooks) + 1), decoded_dim),
            nn.SiLU(),
            nn.Linear(decoded_dim, decoded_dim),
        )
        self.pixel_linear = nn.Linear(self.C_feat, self.num_channels)

    @staticmethod
    def _patchify(x, ph, pw):
        """(B, H, W, C) -> (B, hh*ww, ph*pw*C); same as 'b (hh ph) (ww pw) c -> b (hh ww) (ph pw c)'."""
        b, h, w, c = x.shape
        hh, ww = h // ph, w // pw
        x = x.reshape(b, hh, ph, ww, pw, c).permute(0, 1, 3, 2, 4, 5)
        return x.reshape(b, hh * ww, ph * pw * c)

    @staticmethod
    def _unpatchify(x, hh, ww, ph, pw):
        """(B, hh*ww*ph*pw, C) -> (B, C, hh*ph, ww*pw); inverse layout of ``_patchify``."""
        b, c = x.shape[0], x.shape[-1]
        x = x.reshape(b, hh, ww, ph, pw, c).permute(0, 5, 1, 3, 2, 4)
        return x.reshape(b, c, hh * ph, ww * pw)

    def img_pts_tokenizer(self, imgs):
        """
        Turn images into one token per patch.

        Args:
            imgs: tensor of shape (N, 3, H, W) with H, W divisible by the patch size.
                NOTE(review): images are fed to VGG as-is - confirm the caller's
                normalisation matches what the pretrained weights expect.

        Returns:
            Tensor of shape (N, (H // ph) * (W // pw), dim_tokens[0]).

        Raises:
            ValueError: if H or W is not a multiple of the patch size.
        """
        _, _, H, W = imgs.shape
        ph, pw = self.patch_size[0], self.patch_size[1]
        if H % ph != 0 or W % pw != 0:
            raise ValueError(
                f"image size ({H}, {W}) is not divisible by patch size ({ph}, {pw})"
            )

        # VGG is frozen, so no graph is needed for its activations.
        with torch.no_grad():
            vgg_features = self.vgg_feature_extractor(imgs)

        # Concat original images with (upsampled) VGG features, then patchify.
        vgg_features = F.interpolate(vgg_features, size=(H, W), mode='bilinear', align_corners=False)
        combined = torch.cat([imgs, vgg_features], dim=1)  # [N, 3+512, H, W]
        combined = combined.permute(0, 2, 3, 1).contiguous()
        input_patches = self._patchify(combined, ph, pw)

        # The previous code followed this with `self.tokenizer(input)`, which
        # passed the Python builtin `input` and could only raise.
        return self.tokenizer(input_patches)

    def forward(self, x, depths, imgs, img_info, conf=None):
        """
        Decode hooked backbone tokens plus image tokens into a per-pixel map.

        Args:
            x: indexable collection of token tensors; ``x[hook]`` is read for
                every hook. Each is assumed to be (B*V, num_patches, D) with the
                same patch ordering as the image tokens - TODO confirm with caller.
            depths, conf: accepted for interface compatibility; not used.
            imgs: images passed to ``img_pts_tokenizer``.
            img_info: ``(B, V, H, W)``.

        Returns:
            Tensor of shape (B*V, num_channels, H, W).
        """
        B, V, H, W = img_info
        ph, pw = self.patch_size[0], self.patch_size[1]
        hh, ww = H // ph, W // pw

        # Hooked ViT layers followed by the image tokens; the decoder's input
        # width is sized for exactly this concatenation.
        layer_tokens = [x[hook] for hook in self.hooks]  # [B*V, S, D] each
        layer_tokens.append(self.img_pts_tokenizer(imgs))

        feats = self.token_decoder(torch.cat(layer_tokens, dim=-1))
        # Every token expands to ph*pw pixels with C_feat features each.
        feats = feats.reshape(B * V, hh * ww * ph * pw, self.C_feat)
        out_flat = self.pixel_linear(feats)

        # Put each token's pixels back into its own ph x pw window. A plain
        # view to (H, W) would lay every patch out as a raster strip instead.
        return self._unpatchify(out_flat, hh, ww, ph, pw)
# class Pixellevel_Linear_Pts3d(nn.Module):
# """
# Pixel-level linear head for DUST3R
# Each pixel outputs: 3D point (+ confidence)
# """
# def __init__(self, dec_embed_dim, patch_size, depth_mode, conf_mode, has_conf=False, index_hook=[-1]):
# super().__init__()
# self.patch_size = patch_size
# self.depth_mode = depth_mode
# self.conf_mode = conf_mode
# self.has_conf = has_conf
# self.dec_embed_dim = dec_embed_dim
# self.index_hook = index_hook
# # Total embedding dimension per token (possibly concatenated)
# D = self.dec_embed_dim * len(self.index_hook)
# # Ensure divisible into pixel-level features
# assert D % (self.patch_size**2) == 0, \
# f"Embedding dim {D} not divisible by patch_size^2 ({self.patch_size**2})"
# # Feature dimension for each pixel
# self.C_feat = D // (self.patch_size**2) * 4
# # Output channels: x,y,z (+ confidence)
# self.out_dim = 3 + int(self.has_conf)
# self.feat_expand = nn.Sequential(nn.Linear(D, 4*D),
# nn.SiLU(),
# nn.Linear(4*D, 4*D)
# )
# # Per-pixel linear head
# self.pixel_linear = nn.Linear(self.C_feat, self.out_dim)
# def setup(self, croconet):
# pass
# def forward(self, decout, img_shape):
# H, W = img_shape
# # Combine specified decoder tokens: B x num_patches x D
# tokens = [decout[i] for i in self.index_hook]
# x = torch.cat(tokens, dim=-1) # B, S, D
# x = self.feat_expand(x)
# B, S, D = x.shape
# # Validate pixel count
# assert S * (self.patch_size**2) == H * W, \
# f"Mismatch: S*ps^2 ({S*self.patch_size**2}) != H*W ({H*W})"
# # 1. Reshape embedding into pixel features
# # x -> B, S, (ps^2), C_feat -> flatten to B, (S*ps^2), C_feat
# x = x.view(B, S, self.patch_size**2, self.C_feat)
# x = x.reshape(B, S * self.patch_size**2, self.C_feat)
# # 2. Per-pixel linear output
# out_flat = self.pixel_linear(x) # B, S*ps^2, out_dim
# # 3. Reshape to image map: B x out_dim x H x W
# out = out_flat.permute(0, 2, 1).view(B, self.out_dim, H, W)
# # 4. Postprocess depth/conf
# return out
def create_gs_linear_head(net, has_conf=False, out_nchan=3, postprocess_func=postprocess):
    """
    Build an AttnBasedAppearanceHead from the parameters of ``net``.

    Args:
        net: network object providing ``dec_depth`` (must be > 9), ``feature_dim``,
            ``enc_embed_dim``, ``dec_embed_dim``, ``depth_mode`` and ``conf_mode``;
            ``patch_size`` is optional and defaults to ``(16, 16)``.
        has_conf: when truthy, one extra output channel is added.
        out_nchan: number of output channels before the optional confidence channel.
        postprocess_func: forwarded to the head as ``postprocess``.

    Returns:
        The constructed AttnBasedAppearanceHead. The head asserts that all
        entries of ``dim_tokens`` are equal, so ``net.enc_embed_dim`` must
        match ``net.dec_embed_dim``.
    """
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = net.feature_dim
    last_dim = feature_dim // 2
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim

    # Fall back only when the attribute is missing; a bare `except:` would
    # also hide unrelated errors (and KeyboardInterrupt).
    patch_size = getattr(net, 'patch_size', (16, 16))

    return AttnBasedAppearanceHead(
        num_channels=out_nchan + has_conf,
        patch_size=patch_size,
        feature_dim=feature_dim,
        last_dim=last_dim,
        # First layer, two intermediate decoder depths, and the last layer.
        hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
        dim_tokens=[ed, dd, dd, dd],
        postprocess=postprocess_func,
        depth_mode=net.depth_mode,
        conf_mode=net.conf_mode,
        head_type='gs_params',
    )