import functools
from dataclasses import dataclass
from typing import Literal

import torch
import torch.nn.functional as F
import torchvision
from einops import rearrange
from jaxtyping import Float
from torch import Tensor, nn
from torchvision.models import ResNet

from src.dataset.types import BatchedViews
from .backbone import Backbone


@dataclass
class BackboneResnetCfg:
    """Configuration for :class:`BackboneResnet`."""

    # Discriminator tag identifying this backbone type.
    name: Literal["resnet"]
    # Which ResNet to build: a torchvision constructor name, or "dino_resnet50"
    # to fetch the DINO ResNet-50 through torch.hub.
    model: Literal[
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "dino_resnet50"
    ]
    # Number of feature levels that are extracted and summed. Level 0 is the
    # stem (conv1/bn1/relu); levels 1..num_layers-1 are model.layer1, layer2, ...
    num_layers: int
    # Whether to apply the ResNet's max-pool between the stem and layer1.
    use_first_pool: bool
    # Channel count of the projected output features.
    d_out: int


class BackboneResnet(Backbone[BackboneResnetCfg]):
    """ResNet feature extractor that fuses multi-scale features.

    The stem output and the output of each used residual stage are projected
    to ``cfg.d_out`` channels with a 1x1 convolution, bilinearly resized to
    the input resolution, and summed.
    """

    model: ResNet

    def __init__(self, cfg: BackboneResnetCfg, d_in: int) -> None:
        """Build the ResNet and the per-level 1x1 projection layers.

        Args:
            cfg: Backbone configuration.
            d_in: Number of input image channels; must be 3.
        """
        super().__init__(cfg)

        assert d_in == 3, f"BackboneResnet expects 3 input channels, got {d_in}"

        norm_layer = functools.partial(
            nn.InstanceNorm2d,
            affine=False,
            track_running_stats=False,
        )

        if cfg.model == "dino_resnet50":
            # NOTE: the hub model is loaded as published; the instance-norm
            # override above is only applied to torchvision-constructed models.
            self.model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50")
        else:
            self.model = getattr(torchvision.models, cfg.model)(norm_layer=norm_layer)

        # Set up projections. Creation order (layer1.., then layer0) is kept
        # as-is so that parameter registration order does not change.
        self.projections = nn.ModuleDict({})
        for index in range(1, cfg.num_layers):
            key = f"layer{index}"
            d_layer_out = self._last_conv_out_channels(getattr(self.model, key))
            self.projections[key] = nn.Conv2d(d_layer_out, cfg.d_out, 1)

        # Add a projection for the first layer.
        self.projections["layer0"] = nn.Conv2d(
            self.model.conv1.out_channels, cfg.d_out, 1
        )

    @staticmethod
    def _last_conv_out_channels(stage: nn.Module) -> int:
        """Return ``out_channels`` of the highest-numbered ``conv{i}`` in ``stage[-1]``.

        Probes ``conv1``, ``conv2``, ... on the final block of the stage and
        stops at the first missing attribute, so it works for blocks with
        differing numbers of convolutions.

        Raises:
            ValueError: If the final block has no ``conv1`` attribute.
        """
        last_block = stage[-1]
        d_layer_out = None
        conv_index = 1
        while (conv := getattr(last_block, f"conv{conv_index}", None)) is not None:
            d_layer_out = conv.out_channels
            conv_index += 1
        if d_layer_out is None:
            raise ValueError(
                f"{type(last_block).__name__} has no 'conv1' attribute; cannot "
                "infer the stage's output channel count"
            )
        return d_layer_out

    def forward(
        self,
        context: BatchedViews,
    ) -> Float[Tensor, "batch view d_out height width"]:
        """Compute fused multi-scale features for every view in the batch.

        Args:
            context: Batched views; ``context["image"]`` is read as a
                ``(batch, view, channel, height, width)`` tensor.

        Returns:
            Features of shape ``(batch, view, d_out, height, width)`` at the
            input image resolution.
        """
        # Merge the batch dimensions.
        b, v, _, h, w = context["image"].shape
        x = rearrange(context["image"], "b v c h w -> (b v) c h w")

        # Run the images through the resnet.
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        features = [self.projections["layer0"](x)]

        # Propagate the input through the resnet's layers.
        for index in range(1, self.cfg.num_layers):
            key = f"layer{index}"
            # The pool sits between the stem and layer1. The loop starts at 1,
            # so the check must be against 1 (a check against 0 never fires and
            # would silently ignore `use_first_pool`).
            if index == 1 and self.cfg.use_first_pool:
                x = self.model.maxpool(x)
            x = getattr(self.model, key)(x)
            features.append(self.projections[key](x))

        # Upscale the features.
        features = [
            F.interpolate(f, (h, w), mode="bilinear", align_corners=True)
            for f in features
        ]
        features = torch.stack(features).sum(dim=0)

        # Separate batch dimensions.
        return rearrange(features, "(b v) c h w -> b v c h w", b=b, v=v)

    @property
    def d_out(self) -> int:
        """Number of channels in the features returned by :meth:`forward`."""
        return self.cfg.d_out