# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F


class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Stores the resampling settings at construction time so the operation can be
    used like any other layer (e.g. inside ``nn.Sequential``).

    Args:
        scale_factor (float | tuple): Spatial multiplier passed to ``interpolate``.
        mode (str): Interpolation mode passed to ``interpolate`` (e.g. ``"bilinear"``).
        align_corners (bool | None): Passed to ``interpolate``. Default: False.
            NOTE: PyTorch rejects any non-``None`` value for modes such as
            ``"nearest"`` and ``"area"``; pass ``None`` for those modes.
    """

    # TODO: Not sure about normalization generally (norm_cfg in original code)
    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample ``x`` with the stored ``scale_factor``, ``mode`` and ``align_corners``."""
        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return x


class ReassembleBlocks(nn.Module):
    """ViTPostProcessBlock: build a four-level feature pyramid from ViT feature maps.

    Each input feature map is projected with a 1x1 convolution to its stage's
    channel count and then resampled:

    * stage 0: upsampled x4 (4x4 transposed convolution, stride 4),
    * stage 1: upsampled x2 (2x2 transposed convolution, stride 2),
    * stage 2: left unchanged,
    * stage 3: downsampled x2 (3x3 convolution, stride 2, padding 1).

    With the only supported readout type (``'ignore'``) no cls-token handling is
    done here; inputs must already be spatial feature maps, since they are fed
    straight into ``nn.Conv2d``.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List[int] | None): Output channels of each of the four
            stages. ``None`` selects [96, 192, 384, 384]. Default: None.
        readout_type (str): Type of readout operation. Only ``'ignore'`` is
            supported. Default: 'ignore'.
        patch_size (int): The patch size. Stored but not used by ``forward``.
            Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=None, readout_type="ignore", patch_size=16):
        super().__init__()

        if out_channels is None:
            out_channels = [96, 192, 384, 384]

        assert readout_type in ["ignore"]  # ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channel, kernel_size=1) for out_channel in out_channels]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0],
                    out_channels=out_channels[0],
                    kernel_size=4,
                    stride=4,
                    padding=0,
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1],
                    out_channels=out_channels[1],
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3],
                    out_channels=out_channels[3],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )

    def forward(self, inputs):
        """Project and resample each stage.

        Args:
            inputs (Sequence[Tensor]): One feature map per stage, each with
                ``in_channels`` channels (presumably shaped (N, C, H, W) —
                confirm against the backbone).

        Returns:
            List[Tensor]: The resampled feature maps, one per input, in order.
        """
        out = []
        for i, x in enumerate(inputs):
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


class PreActResidualConvUnit(nn.Module):
    """ResidualConvUnit, pre-activation residual unit.

    Computes ``conv2(act(conv1(act(x)))) + x``. When ``bn`` is True, BatchNorm is
    applied after each convolution. Both convolutions keep ``in_channels``
    channels.

    Args:
        in_channels (int): Number of channels in the input feature map.
        stride (int): Stride of the first convolution. Default: 1.
            NOTE: any value other than 1 changes the spatial size of the residual
            branch, so it will generally no longer match the skip connection.
        dilation (int): Dilation rate (also used as padding) of the first
            convolution. Default: 1.
        bn (bool): If True, add BatchNorm2d after each convolution and drop the
            convolution biases. Default: False.
    """

    def __init__(self, in_channels, stride=1, dilation=1, bn=False):
        super().__init__()

        self.bn = bn

        self.act = nn.ReLU()
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=not self.bn,
        )

        self.conv2 = nn.Conv2d(
            in_channels,
            in_channels,
            3,
            padding=1,
            bias=not self.bn,
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(in_channels)
            self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, inputs):
        """Apply the residual branch to ``inputs`` and add ``inputs`` back."""
        # Defensive copy for the skip connection. ``self.act`` is not in-place
        # here, so this only matters if it is swapped for an in-place activation.
        inputs_ = inputs.clone()
        x = self.act(inputs)
        x = self.conv1(x)
        if self.bn:
            x = self.bn1(x)

        x = self.act(x)
        x = self.conv2(x)
        if self.bn:
            x = self.bn2(x)

        return x + inputs_


class FeatureFusionBlock(nn.Module):
    """FeatureFusionBlock, merge feature maps from different stages.

    ``forward(x)`` or ``forward(x, skip)``. If ``skip`` is given, it is:

    1. bilinearly resized to ``x``'s spatial size (only when the shapes differ),
    2. passed through ``res_conv_unit1``,
    3. added to ``x``.

    The result then goes through ``res_conv_unit2``, is upsampled x2 bilinearly,
    and is projected with a 1x1 convolution.

    Args:
        in_channels (int): Input channels.
        expand (bool): If True, the output projection halves the channel count
            (``in_channels // 2``). Default: False.
        align_corners (bool): align_corners setting for the final x2 bilinear
            upsample. Default: True.
    """

    def __init__(self, in_channels, expand=False, align_corners=True):
        super().__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)

        self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels)
        self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels)

    def forward(self, *inputs):
        """Fuse ``inputs[0]`` with an optional second feature map ``inputs[1]``.

        Only exactly two inputs trigger the fusion branch. Any other count uses
        ``inputs[0]`` alone.

        Returns:
            Tensor: The fused map, upsampled x2 and with ``self.out_channels``
            channels.
        """
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                # The skip resize always uses align_corners=False, independent of
                # ``self.align_corners`` (which only affects the final upsample).
                res = F.interpolate(
                    inputs[1],
                    size=(x.shape[2], x.shape[3]),
                    mode="bilinear",
                    align_corners=False,
                )
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        x = self.project(x)
        return x


class OutputHead(nn.Module):
    """Final refinement head: 3x3 conv, x2 transposed-conv upsample, 3x3 conv.

    All layers keep ``latent_size`` channels, and no activations are applied
    between them.

    Args:
        latent_size (int): Channels of the input and the output. Default: 768.
    """

    def __init__(self, latent_size=768):
        super().__init__()
        # TODO: Not sure about structure
        self.head_modules = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=latent_size,
                    out_channels=latent_size,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ConvTranspose2d(
                    in_channels=latent_size,
                    out_channels=latent_size,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Conv2d(
                    in_channels=latent_size,
                    out_channels=latent_size,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
            ]
        )

    def forward(self, x):
        """Run ``x`` through the three head layers in order."""
        for module in self.head_modules:
            x = module(x)
        return x


class DPTHead(nn.Module):
    """Vision Transformers for Dense Prediction.

    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Pipeline:

    1. ``ReassembleBlocks`` turns the four backbone feature maps into a pyramid.
    2. A 3x3 conv maps every level to ``d_out`` channels.
    3. The chain of ``FeatureFusionBlock`` modules fuses the levels from the last
       (most downsampled) to the first (most upsampled).
    4. A 3x3 projection and ``OutputHead`` produce the result.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone. Default: 768.
        post_process_channels (List[int] | None): Out channels of the reassemble
            stages. Each value is capped at ``d_out``. ``None`` or an empty list
            selects [96, 192, 384, 768]. Default: None.
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        d_out (int): Channels used by the fusion blocks and the output.
            Default: 384.
        expand_channels (bool): Stored as ``self.expand_channels``. Currently not
            forwarded to the fusion blocks. Default: False.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=None,
        readout_type="ignore",
        patch_size=16,
        d_out=384,
        expand_channels=False,
    ):
        super().__init__()

        # ``None`` / empty selects the default layout. The fallback must be bound
        # to the local name, because the list comprehension below iterates it.
        if not post_process_channels:
            post_process_channels = [96, 192, 384, 768]
        # No stage may carry more channels than the fusion width.
        self.post_process_channels = [min(d_out, channel) for channel in post_process_channels]
        self.d_out = d_out
        self.expand_channels = expand_channels

        self.reassemble_blocks = ReassembleBlocks(
            embed_dims, self.post_process_channels, readout_type, patch_size
        )

        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(nn.Conv2d(channel, self.d_out, kernel_size=3, padding=1, bias=False))

        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.d_out))
        # The first fusion block is only ever called with a single input (see
        # ``forward``), so its skip-branch unit is never used; drop its params.
        self.fusion_blocks[0].res_conv_unit1 = None

        self.project = nn.Conv2d(self.d_out, self.d_out, kernel_size=3, padding=1)
        self.output_head = OutputHead(d_out)

        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

    def forward(self, inputs):
        """Predict dense features from the backbone feature maps.

        Args:
            inputs (Sequence[Tensor]): One backbone feature map per reassemble
                stage. ``len(inputs)`` must equal ``self.num_reassemble_blocks``.

        Returns:
            List[Tensor]: A single-element list holding the head output.
        """
        assert len(inputs) == self.num_reassemble_blocks
        x = self.reassemble_blocks(list(inputs))
        x = [conv(feature) for conv, feature in zip(self.convs, x)]
        # Start from the last pyramid level, then fuse the remaining levels
        # in reverse order (x[-2], x[-3], ...).
        out = self.fusion_blocks[0](x[-1])
        for fusion_block, feature in zip(self.fusion_blocks[1:], reversed(x[:-1])):
            out = fusion_block(out, feature)
        out = self.project(out)
        out = self.output_head(out)
        return [out]  # list for BTS