import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.utils.checkpoint import checkpoint

from ..dinov2.layers import Mlp
from .attention import FlashAttentionRope
from .block import BlockRope

class TransformerDecoder(nn.Module):
    """Stack of RoPE transformer blocks with a linear input projection
    and a linear output head."""

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()
        # Project incoming tokens to the decoder width (or pass through).
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                # attn_class=MemEffAttentionRope,
                attn_class=FlashAttentionRope,
                rope=rope,
            ) for _ in range(depth)])
        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        hidden = self.projects(hidden)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Trade compute for memory: recompute block activations
                # during the backward pass.
                hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, xpos=xpos)
        out = self.linear_out(hidden)
        return out
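

# Usage sketch (illustrative, not part of the original file). Builds a small
# decoder and runs a forward pass; assumes BlockRope/FlashAttentionRope accept
# rope=None and xpos=None (i.e. fall back to attention without rotary
# embeddings). The sizes and the helper name are arbitrary.
def _decoder_smoke_test():
    import torch
    dec = TransformerDecoder(in_dim=1024, out_dim=768, dec_embed_dim=512, depth=5)
    tokens = torch.randn(2, 196, 1024)   # (B, S, in_dim)
    out = dec(tokens)                    # (B, S, out_dim)
    assert out.shape == (2, 196, 768)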


class LinearPts3d(nn.Module):
    """
    Linear head for dust3r.
    Each token outputs a 16x16 grid of 3D points (+ confidence, when
    output_dim includes a confidence channel).
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=3):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(dec_embed_dim, output_dim * self.patch_size ** 2)

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]          # last decoder layer: (B, S, D)
        B, S, D = tokens.shape

        # Predict a patch of 3D points from each token.
        feat = self.proj(tokens)     # (B, S, output_dim * patch_size**2)
        feat = feat.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)   # (B, output_dim, H, W)

        # Move channels last: (B, H, W, output_dim).
        return feat.permute(0, 2, 3, 1)
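

# Usage sketch (illustrative, not part of the original file). Decodes one
# token per 16x16 patch of a 224x224 image into per-pixel 3D points; the
# single-element list stands in for a full decoder output and the helper
# name is arbitrary.
def _head_smoke_test():
    import torch
    head = LinearPts3d(patch_size=16, dec_embed_dim=512, output_dim=3)
    H, W = 224, 224
    tokens = torch.randn(2, (H // 16) * (W // 16), 512)   # (B, S, D)
    pts = head([tokens], (H, W))
    assert pts.shape == (2, H, W, 3)   # channels-last 3D points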