# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from dust3r.heads.postprocess import (
postprocess,
postprocess_desc,
postprocess_rgb,
postprocess_pose_conf,
postprocess_pose,
reg_dense_conf,
)
import dust3r.utils.path_to_croco # noqa
from models.blocks import Mlp # noqa
from dust3r.utils.geometry import geotrf
from dust3r.utils.camera import pose_encoding_to_camera, PoseDecoder
from dust3r.blocks import ConditionModulationBlock
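

# All heads below share one decoding pattern: each decoder token is projected
# to patch_size**2 per-pixel values, the token grid is reshaped to
# (B, C * patch_size**2, H // patch_size, W // patch_size), and F.pixel_shuffle
# rearranges it into a full-resolution (B, C, H, W) map. The variants differ in
# what they regress: cross-/self-view 3D points, confidences, RGB, local
# descriptors, or a camera pose decoded from a dedicated pose token.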


class LinearPts3d(nn.Module):
    """
    Linear head for DUSt3R.
    Each token regresses a patch_size x patch_size block of 3D points in the
    other view (plus, optionally, a confidence map, self-view points, RGB,
    and a pose-confidence map).
    """

def __init__(
self, net, has_conf=False, has_depth=False, has_rgb=False, has_pose_conf=False
):
super().__init__()
self.patch_size = net.patch_embed.patch_size[0]
self.depth_mode = net.depth_mode
self.conf_mode = net.conf_mode
self.has_conf = has_conf
self.has_rgb = has_rgb
self.has_pose_conf = has_pose_conf
self.has_depth = has_depth
self.proj = Mlp(
net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
)
if has_depth:
self.self_proj = Mlp(
net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
)
        if has_rgb:
            self.rgb_proj = Mlp(net.dec_embed_dim, out_features=3 * self.patch_size**2)
        if has_pose_conf:
            # defined here so the has_pose_conf branch in forward() has a
            # projection to call; one output channel per pixel is assumed,
            # mirroring cross_conf_proj in LinearPts3dPoseDirect below
            self.pose_conf_proj = Mlp(
                net.dec_embed_dim, out_features=self.patch_size**2
            )

    def setup(self, croconet):
        pass

def forward(self, decout, img_shape):
H, W = img_shape
tokens = decout[-1]
B, S, D = tokens.shape
        feat = self.proj(tokens)  # B,S,(3+has_conf)*patch_size**2
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3+has_conf,H,W
final_output = postprocess(feat, self.depth_mode, self.conf_mode)
final_output["pts3d_in_other_view"] = final_output.pop("pts3d")
if self.has_depth:
            self_feat = self.self_proj(tokens)  # B,S,(3+has_conf)*patch_size**2
            self_feat = self_feat.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            self_feat = F.pixel_shuffle(self_feat, self.patch_size)  # B,3+has_conf,H,W
self_3d_output = postprocess(self_feat, self.depth_mode, self.conf_mode)
self_3d_output["pts3d_in_self_view"] = self_3d_output.pop("pts3d")
self_3d_output["conf_self"] = self_3d_output.pop("conf")
final_output.update(self_3d_output)
if self.has_rgb:
rgb_feat = self.rgb_proj(tokens)
rgb_feat = rgb_feat.transpose(-1, -2).view(
B, -1, H // self.patch_size, W // self.patch_size
)
rgb_feat = F.pixel_shuffle(rgb_feat, self.patch_size) # B,3,H,W
rgb_output = postprocess_rgb(rgb_feat)
final_output.update(rgb_output)
if self.has_pose_conf:
pose_conf = self.pose_conf_proj(tokens)
pose_conf = pose_conf.transpose(-1, -2).view(
B, -1, H // self.patch_size, W // self.patch_size
)
pose_conf = F.pixel_shuffle(pose_conf, self.patch_size)
pose_conf_output = postprocess_pose_conf(pose_conf)
final_output.update(pose_conf_output)
return final_output
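

# LinearPts3d usage (shapes): decout[-1] is B x S x dec_embed_dim with
# S = (H // patch_size) * (W // patch_size); the returned dict maps
# "pts3d_in_other_view" to B x H x W x 3 (and "conf" to B x H x W when
# conf_mode is set). A runnable stub is sketched at the bottom of this file.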


class LinearPts3d_Desc(nn.Module):
    """
    Linear head for DUSt3R with local feature descriptors.
    Each token regresses a patch_size x patch_size block of 3D points (plus
    optional confidence, doubled channels when depth is also predicted),
    together with per-pixel local descriptors regressed from concatenated
    encoder and decoder tokens.
    """

def __init__(
self,
net,
has_conf=False,
has_depth=False,
local_feat_dim=24,
hidden_dim_factor=4.0,
):
super().__init__()
self.patch_size = net.patch_embed.patch_size[0]
self.depth_mode = net.depth_mode
self.conf_mode = net.conf_mode
self.has_conf = has_conf
self.double_channel = has_depth
self.local_feat_dim = local_feat_dim
if not has_depth:
self.proj = nn.Linear(
net.dec_embed_dim, (3 + has_conf) * self.patch_size**2
)
else:
self.proj = nn.Linear(
net.dec_embed_dim, (3 + has_conf) * 2 * self.patch_size**2
)
idim = net.enc_embed_dim + net.dec_embed_dim
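        # descriptor head: consumes concatenated encoder+decoder tokens and
        # regresses local_feat_dim descriptor channels plus one extra channel
        # per pixel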
self.head_local_features = Mlp(
in_features=idim,
hidden_features=int(hidden_dim_factor * idim),
out_features=(self.local_feat_dim + 1) * self.patch_size**2,
)

    def setup(self, croconet):
        pass

def forward(self, decout, img_shape):
H, W = img_shape
tokens = decout[-1]
B, S, D = tokens.shape
        # C = (3 + has_conf) * patch_size**2, doubled when depth is predicted too
        feat = self.proj(tokens)  # B,S,C
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,C//patch_size**2,H,W
        enc_output, dec_output = decout[0], decout[-1]
        cat_output = torch.cat([enc_output, dec_output], dim=-1)
        local_features = self.head_local_features(cat_output)  # B,S,(d+1)*patch_size**2
        local_features = local_features.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        local_features = F.pixel_shuffle(local_features, self.patch_size)  # B,d+1,H,W
feat = torch.cat([feat, local_features], dim=1)
return postprocess_desc(
feat,
self.depth_mode,
self.conf_mode,
self.local_feat_dim,
self.double_channel,
)


class LinearPts3dPoseDirect(nn.Module):
    """
    Linear head for DUSt3R that additionally decodes a camera pose from a
    dedicated pose token. Points in the other view are obtained by applying
    the decoded pose to the self-view points.
    """

def __init__(self, net, has_conf=False, has_rgb=False, has_pose=False):
super().__init__()
self.patch_size = net.patch_embed.patch_size[0]
self.depth_mode = net.depth_mode
self.conf_mode = net.conf_mode
self.pose_mode = net.pose_mode
self.has_conf = has_conf
self.has_rgb = has_rgb
self.has_pose = has_pose
self.proj = Mlp(
net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
)
if has_rgb:
self.rgb_proj = Mlp(net.dec_embed_dim, out_features=3 * self.patch_size**2)
if has_pose:
self.pose_head = PoseDecoder(hidden_size=net.dec_embed_dim)
if has_conf:
self.cross_conf_proj = Mlp(
net.dec_embed_dim, out_features=self.patch_size**2
)

    def setup(self, croconet):
        pass

def forward(self, decout, img_shape):
H, W = img_shape
tokens = decout[-1]
if self.has_pose:
pose_token = tokens[:, 0]
tokens = tokens[:, 1:]
B, S, D = tokens.shape
        feat = self.proj(tokens)  # B,S,(3+has_conf)*patch_size**2
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3+has_conf,H,W
final_output = postprocess(feat, self.depth_mode, self.conf_mode)
final_output["pts3d_in_self_view"] = final_output.pop("pts3d")
final_output["conf_self"] = final_output.pop("conf")
if self.has_rgb:
rgb_feat = self.rgb_proj(tokens)
rgb_feat = rgb_feat.transpose(-1, -2).view(
B, -1, H // self.patch_size, W // self.patch_size
)
rgb_feat = F.pixel_shuffle(rgb_feat, self.patch_size) # B,3,H,W
rgb_output = postprocess_rgb(rgb_feat)
final_output.update(rgb_output)
if self.has_pose:
pose = self.pose_head(pose_token)
pose = postprocess_pose(pose, self.pose_mode)
final_output["camera_pose"] = pose # B,7
final_output["pts3d_in_other_view"] = geotrf(
pose_encoding_to_camera(final_output["camera_pose"]),
final_output["pts3d_in_self_view"],
)
if self.has_conf:
cross_conf = self.cross_conf_proj(tokens)
cross_conf = cross_conf.transpose(-1, -2).view(
B, -1, H // self.patch_size, W // self.patch_size
)
cross_conf = F.pixel_shuffle(cross_conf, self.patch_size)[:, 0]
final_output["conf"] = reg_dense_conf(cross_conf, mode=self.conf_mode)
return final_output


class LinearPts3dPose(nn.Module):
    """
    Linear head for DUSt3R with pose decoding and pose-conditioned
    refinement: cross-view points are regressed from tokens modulated by the
    pose token via ConditionModulationBlock layers.
    """

def __init__(
self, net, has_conf=False, has_rgb=False, has_pose=False, mlp_ratio=4.0
):
super().__init__()
self.patch_size = net.patch_embed.patch_size[0]
self.depth_mode = net.depth_mode
self.conf_mode = net.conf_mode
self.pose_mode = net.pose_mode
self.has_conf = has_conf
self.has_rgb = has_rgb
self.has_pose = has_pose
self.proj = Mlp(
net.dec_embed_dim,
hidden_features=int(mlp_ratio * net.dec_embed_dim),
out_features=(3 + has_conf) * self.patch_size**2,
)
if has_rgb:
self.rgb_proj = Mlp(
net.dec_embed_dim,
hidden_features=int(mlp_ratio * net.dec_embed_dim),
out_features=3 * self.patch_size**2,
)
if has_pose:
self.pose_head = PoseDecoder(hidden_size=net.dec_embed_dim)
self.final_transform = nn.ModuleList(
[
ConditionModulationBlock(
net.dec_embed_dim,
net.dec_num_heads,
mlp_ratio=4.0,
qkv_bias=True,
rope=net.rope,
)
for _ in range(2)
]
)
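            # cross-view points are regressed in forward() from tokens refined
            # by the pose-conditioned modulation blocks above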
self.cross_proj = Mlp(
net.dec_embed_dim,
hidden_features=int(mlp_ratio * net.dec_embed_dim),
out_features=(3 + has_conf) * self.patch_size**2,
)

    def setup(self, croconet):
        pass

def forward(self, decout, img_shape, **kwargs):
H, W = img_shape
tokens = decout[-1]
if self.has_pose:
pose_token = tokens[:, 0]
tokens = tokens[:, 1:]
with torch.cuda.amp.autocast(enabled=False):
pose = self.pose_head(pose_token)
cross_tokens = tokens
for blk in self.final_transform:
cross_tokens = blk(cross_tokens, pose_token, kwargs.get("pos"))
with torch.cuda.amp.autocast(enabled=False):
B, S, D = tokens.shape
            feat = self.proj(tokens)  # B,S,(3+has_conf)*patch_size**2
            feat = feat.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            feat = F.pixel_shuffle(feat, self.patch_size)  # B,3+has_conf,H,W
final_output = postprocess(
feat, self.depth_mode, self.conf_mode, pos_z=True
)
final_output["pts3d_in_self_view"] = final_output.pop("pts3d")
final_output["conf_self"] = final_output.pop("conf")
if self.has_rgb:
rgb_feat = self.rgb_proj(tokens)
rgb_feat = rgb_feat.transpose(-1, -2).view(
B, -1, H // self.patch_size, W // self.patch_size
)
rgb_feat = F.pixel_shuffle(rgb_feat, self.patch_size) # B,3,H,W
rgb_output = postprocess_rgb(rgb_feat)
final_output.update(rgb_output)
if self.has_pose:
pose = postprocess_pose(pose, self.pose_mode)
final_output["camera_pose"] = pose # B,7
                cross_feat = self.cross_proj(cross_tokens)  # B,S,(3+has_conf)*patch_size**2
                cross_feat = cross_feat.transpose(-1, -2).view(
                    B, -1, H // self.patch_size, W // self.patch_size
                )
                cross_feat = F.pixel_shuffle(cross_feat, self.patch_size)  # B,3+has_conf,H,W
tmp = postprocess(cross_feat, self.depth_mode, self.conf_mode)
final_output["pts3d_in_other_view"] = tmp.pop("pts3d")
final_output["conf"] = tmp.pop("conf")
return final_output
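

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not from the original repo). The mode
    # tuples are assumed to match DUSt3R-style defaults ("exp", vmin, vmax),
    # and only the attributes LinearPts3d actually reads are stubbed out.
    from types import SimpleNamespace

    inf = float("inf")
    net = SimpleNamespace(
        patch_embed=SimpleNamespace(patch_size=(16, 16)),
        depth_mode=("exp", -inf, inf),
        conf_mode=("exp", 1, inf),
        dec_embed_dim=768,
    )
    head = LinearPts3d(net, has_conf=True)
    H = W = 224
    tokens = torch.randn(2, (H // 16) * (W // 16), net.dec_embed_dim)
    out = head([tokens], (H, W))
    # expected: pts3d_in_other_view (2, 224, 224, 3), conf (2, 224, 224)
    print({k: tuple(v.shape) for k, v in out.items()})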