|
from dataclasses import dataclass |
|
from typing import Literal |
|
|
|
import torch |
|
from einops import rearrange, repeat |
|
from jaxtyping import Float |
|
from torch import Tensor, nn |
|
|
|
from src.dataset.types import BatchedViews |
|
from .backbone import Backbone |
|
from .backbone_resnet import BackboneResnet, BackboneResnetCfg |
|
|
|
|
|
@dataclass
class BackboneDinoCfg:
    """Configuration for the DINO-based feature backbone.

    Fields are positional in this order: ``name``, ``model``, ``d_out``.
    """

    # Backbone identifier; the only accepted value is "dino".
    name: Literal["dino"]
    # Name of the DINO ViT variant to load.
    model: Literal[
        "dino_vits16",
        "dino_vits8",
        "dino_vitb16",
        "dino_vitb8",
    ]
    # Number of channels in the backbone's output features.
    d_out: int
|
|
|
|
|
class BackboneDino(Backbone[BackboneDinoCfg]):
    """Feature backbone that combines a DINO ViT with a ResNet branch.

    For each context view, the output is the sum of three ``d_out``-channel
    feature maps:

    * the ResNet backbone's features for the view,
    * the ViT's per-patch tokens, projected by an MLP and upsampled by
      repeating each token over its ``patch_size x patch_size`` pixel block,
    * the ViT's first token (treated here as a global token), projected by an
      MLP and broadcast over every pixel.
    """

    def __init__(self, cfg: BackboneDinoCfg, d_in: int) -> None:
        """Build the DINO ViT, the ResNet branch, and the token projection MLPs.

        Args:
            cfg: Backbone configuration. ``cfg.model`` selects the DINO ViT
                variant loaded via ``torch.hub``; ``cfg.d_out`` is the number
                of output feature channels.
            d_in: Number of input image channels; must be 3.
        """
        super().__init__(cfg)
        assert d_in == 3, f"BackboneDino expects 3 input channels, got {d_in}"
        self.dino = torch.hub.load("facebookresearch/dino:main", cfg.model)
        self.resnet_backbone = BackboneResnet(
            BackboneResnetCfg("resnet", "dino_resnet50", 4, False, cfg.d_out),
            d_in,
        )

        # The token width depends on the ViT variant (384 for ViT-S, 768 for
        # ViT-B). Reading it from the loaded model keeps every variant allowed
        # by BackboneDinoCfg working; a hard-coded 768 breaks the ViT-S models.
        d_token = self.dino.embed_dim
        self.global_token_mlp = self._make_token_mlp(d_token, cfg.d_out)
        self.local_token_mlp = self._make_token_mlp(d_token, cfg.d_out)

    @staticmethod
    def _make_token_mlp(d_token: int, d_out: int) -> nn.Sequential:
        """Two-layer MLP (Linear -> ReLU -> Linear) mapping d_token to d_out."""
        return nn.Sequential(
            nn.Linear(d_token, d_token),
            nn.ReLU(),
            nn.Linear(d_token, d_out),
        )

    def forward(
        self,
        context: BatchedViews,
    ) -> Float[Tensor, "batch view d_out height width"]:
        """Compute per-pixel features for every context view.

        Args:
            context: Batched views. ``context["image"]`` is unpacked as a
                ``(batch, view, channel, height, width)`` tensor; height and
                width must both be multiples of ``patch_size``.

        Returns:
            ResNet features plus the upsampled per-patch token features plus
            the broadcast global token features.
        """
        resnet_features = self.resnet_backbone(context)

        b, v, _, h, w = context["image"].shape
        patch = self.patch_size
        assert h % patch == 0 and w % patch == 0, (
            f"image size ({h}, {w}) must be divisible by patch size {patch}"
        )

        # Fold views into the batch dimension for the ViT.
        images = rearrange(context["image"], "b v c h w -> (b v) c h w")
        # With the default n=1, get_intermediate_layers returns a one-element
        # list; index 0 is the token sequence, with the global token first.
        tokens = self.dino.get_intermediate_layers(images)[0]
        global_token = self.global_token_mlp(tokens[:, 0])
        local_tokens = self.local_token_mlp(tokens[:, 1:])

        # Broadcast each view's global token to every pixel of that view.
        global_token = repeat(global_token, "(b v) c -> b v c h w", b=b, v=v, h=h, w=w)

        # Patch tokens are laid out row-major over the (h/patch, w/patch)
        # grid; repeat each one over its patch x patch pixel block.
        local_tokens = repeat(
            local_tokens,
            "(b v) (h w) c -> b v c (h hps) (w wps)",
            b=b,
            v=v,
            h=h // patch,
            hps=patch,
            w=w // patch,
            wps=patch,
        )

        return resnet_features + local_tokens + global_token

    @property
    def patch_size(self) -> int:
        """ViT patch size parsed from the digits in ``cfg.model`` (e.g. "dino_vits16" -> 16)."""
        return int("".join(filter(str.isdigit, self.cfg.model)))

    @property
    def d_out(self) -> int:
        """Number of output feature channels (``cfg.d_out``)."""
        return self.cfg.d_out
|
|