Spaces:
Runtime error
Runtime error
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # linear head implementation for DUST3R | |
| # -------------------------------------------------------- | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .postprocess import postprocess | |
| class LinearPts3d (nn.Module): | |
| """ | |
| Linear head for dust3r | |
| Each token outputs: - 16x16 3D points (+ confidence) | |
| """ | |
| def __init__(self, net, has_conf=False): | |
| super().__init__() | |
| self.patch_size = net.patch_embed.patch_size[0] | |
| self.depth_mode = net.depth_mode | |
| self.conf_mode = net.conf_mode | |
| self.has_conf = has_conf | |
| self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) | |
| def setup(self, croconet): | |
| pass | |
| def forward(self, decout, img_shape): | |
| H, W = img_shape | |
| tokens = decout[-1] | |
| B, S, D = tokens.shape | |
| # extract 3D points | |
| feat = self.proj(tokens) # B,S,D | |
| feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) | |
| feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W | |
| # permute + norm depth | |
| return postprocess(feat, self.depth_mode, self.conf_mode) | |