Spaces:
Running
Running
| """Copyright (C) 2024 Apple Inc. All Rights Reserved. | |
| Dense Prediction Transformer Decoder architecture. | |
| Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413 | |
| """ | |
| from __future__ import annotations | |
| from typing import Iterable | |
| import torch | |
| from torch import nn | |
class MultiresConvDecoder(nn.Module):
    """Decoder for multi-resolution encodings.

    Each encoder level is projected to ``dim_decoder`` channels, then the levels
    are fused from the lowest resolution (last entry) to the highest (entry 0)
    with one ``FeatureFusionBlock2d`` per level.
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        dim_decoder: int,
    ) -> None:
        """Initialize multiresolution convolutional decoder.

        Args:
        ----
            dims_encoder: Expected dims at each level from the encoder, ordered
                from the highest resolution (level 0) to the lowest.
            dim_decoder: Dim of decoder features.

        Raises:
        ------
            ValueError: If ``dims_encoder`` is empty.

        """
        super().__init__()
        self.dims_encoder = list(dims_encoder)
        self.dim_decoder = dim_decoder
        self.dim_out = dim_decoder

        num_encoders = len(self.dims_encoder)
        if num_encoders == 0:
            # Without this guard, dims_encoder[0] below fails with an opaque IndexError.
            raise ValueError("dims_encoder must contain at least one level.")

        # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
        # when the dimensions mismatch. Otherwise we do not do anything, which is
        # the default behavior of monodepth.
        conv0 = (
            nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
            if self.dims_encoder[0] != dim_decoder
            else nn.Identity()
        )
        # Every lower-resolution level is always projected with a 3x3 convolution.
        convs = [conv0] + [
            nn.Conv2d(
                dim_in,
                dim_decoder,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )
            for dim_in in self.dims_encoder[1:]
        ]
        self.convs = nn.ModuleList(convs)

        # All levels except level 0 enable the fusion block's deconv step.
        self.fusions = nn.ModuleList(
            [
                FeatureFusionBlock2d(
                    num_features=dim_decoder,
                    deconv=(i != 0),
                    batch_norm=False,
                )
                for i in range(num_encoders)
            ]
        )

    def forward(
        self, encodings: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode the multi-resolution encodings.

        Args:
        ----
            encodings: One feature map per encoder level, ordered from the
                highest resolution (index 0) to the lowest (index -1).

        Returns:
        -------
            A tuple ``(features, lowres_features)``: the fused decoder output,
            and the projected lowest-resolution features before any fusion.

        Raises:
        ------
            ValueError: If ``len(encodings)`` differs from ``len(dims_encoder)``.

        """
        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)
        if num_levels != num_encoders:
            raise ValueError(
                f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
            )

        # Project features of different encoder dims to the same decoder dim.
        # Fuse features from the lowest resolution (num_levels-1)
        # to the highest (0).
        features = self.convs[-1](encodings[-1])
        lowres_features = features
        features = self.fusions[-1](features)
        for i in range(num_levels - 2, -1, -1):
            features_i = self.convs[i](encodings[i])
            features = self.fusions[i](features, features_i)
        return features, lowres_features
class ResidualBlock(nn.Module):
    """Residual wrapper that returns ``shortcut(x) + residual(x)``.

    Generic form of the block from He et al., "Identity Mappings in Deep
    Residual Networks" (2016), https://arxiv.org/abs/1603.05027. The residual
    branch is supplied by the caller (e.g. built by a factory function); when
    no shortcut module is given, the input is added back unchanged.
    """

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Store the residual branch and the optional shortcut module.

        Args:
        ----
            residual: Module producing the residual term from the input.
            shortcut: Optional module applied to the input before the addition;
                ``None`` means the input is used as-is.

        """
        super().__init__()
        # Registration order (residual, then shortcut) is kept stable.
        self.residual = residual
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (optionally projected) input plus the residual branch output."""
        # The residual branch is evaluated first, before the shortcut touches x.
        branch_out = self.residual(x)
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + branch_out
class FeatureFusionBlock2d(nn.Module):
    """Feature fusion for DPT.

    Optionally adds skip features (after their own residual block) to the main
    input, refines the result with a second residual block, optionally applies
    a stride-2 transposed convolution, and ends with a 1x1 output convolution.
    """

    def __init__(
        self,
        num_features: int,
        deconv: bool = False,
        batch_norm: bool = False,
    ) -> None:
        """Initialize feature fusion block.

        Args:
        ----
            num_features: Input and output dimensions.
            deconv: Whether to use deconv before the final output conv.
            batch_norm: Whether to use batch normalization in resnet blocks.

        """
        super().__init__()

        # resnet1 processes the optional skip input; resnet2 processes the fused sum.
        self.resnet1 = self._residual_block(num_features, batch_norm)
        self.resnet2 = self._residual_block(num_features, batch_norm)

        self.use_deconv = deconv
        if deconv:
            # kernel_size=2, stride=2, padding=0 gives an output exactly twice
            # the input's spatial size.
            self.deconv = nn.ConvTranspose2d(
                in_channels=num_features,
                out_channels=num_features,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )

        self.out_conv = nn.Conv2d(
            num_features,
            num_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # The addition goes through a FloatFunctional module (PyTorch's
        # quantization-oriented wrapper) instead of a bare `+`.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Process and fuse input features.

        Args:
        ----
            x0: Main input features.
            x1: Optional skip features; when given, they pass through ``resnet1``
                and are added to ``x0``.

        Returns:
        -------
            The fused features after ``resnet2``, the optional deconv, and
            ``out_conv``.

        """
        x = x0

        if x1 is not None:
            res = self.resnet1(x1)
            x = self.skip_add.add(x, res)

        x = self.resnet2(x)

        if self.use_deconv:
            x = self.deconv(x)
        x = self.out_conv(x)

        return x

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool) -> nn.Module:
        """Create a residual block of two ReLU -> 3x3 conv (-> BatchNorm) stages.

        Must be a staticmethod: it is invoked as
        ``self._residual_block(num_features, batch_norm)`` and takes no ``self``.
        """

        def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
            layers: list[nn.Module] = [
                nn.ReLU(False),
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    # No conv bias when a BatchNorm follows; its shift covers it.
                    bias=not batch_norm,
                ),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(dim))
            return layers

        residual = nn.Sequential(
            *_create_block(dim=num_features, batch_norm=batch_norm),
            *_create_block(dim=num_features, batch_norm=batch_norm),
        )
        return ResidualBlock(residual)