# File size: 3,048 bytes (revision 2568013)
import functools
from dataclasses import dataclass
from typing import Literal
import torch
import torch.nn.functional as F
import torchvision
from einops import rearrange
from jaxtyping import Float
from torch import Tensor, nn
from torchvision.models import ResNet
from src.dataset.types import BatchedViews
from .backbone import Backbone
@dataclass
class BackboneResnetCfg:
    """Configuration for the ResNet-based backbone."""

    # Discriminator identifying this backbone type; always "resnet".
    name: Literal["resnet"]
    # Which ResNet variant to instantiate.
    model: Literal[
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "dino_resnet50"
    ]
    # Number of feature levels to use.
    num_layers: int
    # Whether to apply the ResNet's max-pool after the stem.
    use_first_pool: bool
    # Channel count of the backbone's output feature map.
    d_out: int
class BackboneResnet(Backbone[BackboneResnetCfg]):
    """ResNet feature extractor that fuses several feature levels into one map.

    The stem output ("layer0") and the outputs of ``layer1`` through
    ``layer{num_layers - 1}`` are each projected to ``cfg.d_out`` channels
    with a 1x1 convolution, resized to the input resolution, and summed.
    """

    model: ResNet

    def __init__(self, cfg: BackboneResnetCfg, d_in: int) -> None:
        """Build the ResNet and one 1x1 projection per feature level.

        Args:
            cfg: Backbone configuration.
            d_in: Number of input image channels; must be 3.
        """
        super().__init__(cfg)

        assert d_in == 3, f"BackboneResnet expects 3 input channels, got {d_in}"

        # Normalization passed to the torchvision constructors: instance norm
        # without learnable affine parameters or running statistics.
        norm_layer = functools.partial(
            nn.InstanceNorm2d,
            affine=False,
            track_running_stats=False,
        )

        if cfg.model == "dino_resnet50":
            # Loaded via torch.hub; norm_layer is not applied on this path.
            self.model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50")
        else:
            self.model = getattr(torchvision.models, cfg.model)(norm_layer=norm_layer)

        # Set up projections. Registration order (layer1.., then layer0) is
        # kept as in the original so parameter ordering does not change.
        self.projections = nn.ModuleDict({})
        for index in range(1, cfg.num_layers):
            key = f"layer{index}"
            d_layer_out = self._last_conv_out_channels(getattr(self.model, key))
            self.projections[key] = nn.Conv2d(d_layer_out, cfg.d_out, 1)

        # Add a projection for the first layer.
        self.projections["layer0"] = nn.Conv2d(
            self.model.conv1.out_channels, cfg.d_out, 1
        )

    @staticmethod
    def _last_conv_out_channels(stage: nn.Module) -> int:
        """Return ``out_channels`` of the highest-numbered ``convN`` in ``stage[-1]``.

        Probes ``conv1``, ``conv2``, ... on the last element of ``stage`` and
        stops at the first index that does not exist.

        Raises:
            ValueError: If the last element has no ``conv1`` attribute.
        """
        last_unit = stage[-1]
        out_channels = None
        conv_index = 1
        while hasattr(last_unit, f"conv{conv_index}"):
            out_channels = getattr(last_unit, f"conv{conv_index}").out_channels
            conv_index += 1
        if out_channels is None:
            raise ValueError(
                f"Could not find a conv1 attribute on {type(last_unit).__name__}"
            )
        return out_channels

    def forward(
        self,
        context: BatchedViews,
    ) -> Float[Tensor, "batch view d_out height width"]:
        """Compute fused features for every view in the batch.

        Args:
            context: Batched views; ``context["image"]`` is unpacked as
                ``(batch, view, channel, height, width)``.

        Returns:
            The sum of all projected feature levels, each resized to the
            input ``(height, width)``, with batch and view axes separated.
        """
        # Merge the batch dimensions.
        b, v, _, h, w = context["image"].shape
        x = rearrange(context["image"], "b v c h w -> (b v) c h w")

        # Run the images through the resnet.
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        features = [self.projections["layer0"](x)]

        # Propagate the input through the resnet's layers.
        for index in range(1, self.cfg.num_layers):
            key = f"layer{index}"
            # The optional max-pool goes between the stem and layer1. The loop
            # starts at 1, so the previous `index == 0` test never fired and
            # use_first_pool was silently ignored.
            # NOTE(review): checkpoints trained with use_first_pool=True under
            # the old behavior effectively ran without the pool.
            if index == 1 and self.cfg.use_first_pool:
                x = self.model.maxpool(x)
            x = getattr(self.model, key)(x)
            features.append(self.projections[key](x))

        # Upscale the features.
        features = [
            F.interpolate(f, (h, w), mode="bilinear", align_corners=True)
            for f in features
        ]
        features = torch.stack(features).sum(dim=0)

        # Separate batch dimensions.
        return rearrange(features, "(b v) c h w -> b v c h w", b=b, v=v)

    @property
    def d_out(self) -> int:
        """Channel count of the feature map returned by ``forward``."""
        return self.cfg.d_out
|